From cc368034ebbb3ba0363304fea7af84295a97c270 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Thu, 30 Jan 2025 11:39:39 -0500 Subject: [PATCH 01/33] Modified Api --- Cargo.toml | 5 +- src/api.rs | 17 +- src/api_modified.rs | 570 ++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 13 +- src/modifier.rs | 168 ------------- src/preset.rs | 234 ----------------- src/preset_builder.rs | 255 ------------------- src/util.rs | 25 -- 8 files changed, 587 insertions(+), 700 deletions(-) create mode 100644 src/api_modified.rs delete mode 100644 src/modifier.rs delete mode 100644 src/preset.rs delete mode 100644 src/preset_builder.rs delete mode 100644 src/util.rs diff --git a/Cargo.toml b/Cargo.toml index 5441bc0..c4df08c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,10 +23,11 @@ documentation = "https://docs.rs/diffusion-rs" [dependencies] derive_builder = "0.20.2" diffusion-rs-sys = { path = "sys", version = "0.1.6" } -hf-hub = {version = "0.4.0", default-features = false, features = ["ureq"]} +hf-hub = { version = "0.4.0", default-features = false, features = ["ureq"] } +image = "0.25.5" libc = "0.2.161" num_cpus = "1.16.0" -thiserror = "2.0.3" +thiserror = "2.0.11" [features] cuda = ["diffusion-rs-sys/cuda"] diff --git a/src/api.rs b/src/api.rs index 1dab93d..d6f4fbe 100644 --- a/src/api.rs +++ b/src/api.rs @@ -57,6 +57,7 @@ pub enum ClipSkip { #[derive(Builder, Debug, Clone)] #[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] + /// Config struct common to all diffusion methods pub struct Config { /// Number of threads to use during computation (default: 0). @@ -209,6 +210,10 @@ pub struct Config { #[builder(default = "false")] vae_tiling: bool, + /// free memory of params immediately after forward (default: true) + #[builder(default = "true")] + free_params_immediately: bool, + /// Keep vae in cpu (for low vram) (default: false) #[builder(default = "false")] vae_on_cpu: bool, @@ -233,7 +238,7 @@ pub struct Config { /// Might lower quality, since it implies converting k and v to f16. /// This might crash if it is not supported by the backend. #[builder(default = "false")] - flash_attenuation: bool, + flash_attention: bool, /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium @@ -254,10 +259,12 @@ pub struct Config { } impl ConfigBuilder { - pub fn lora_model(&mut self, lora_model: &Path) -> &mut Self { + /// add Lora model and clip strength to the prompt suffix + /// e.g. "" + pub fn lora_model(&mut self, lora_model: &Path, clip_strength: f32) -> &mut Self { let folder = lora_model.parent().unwrap(); let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned(); - self.prompt_suffix(format!("")); + self.prompt_suffix(format!("")); self.lora_model = Some(folder.into()); self } @@ -315,7 +322,7 @@ impl Config { self.stacked_id_embd.as_ptr(), vae_decode_only, self.vae_tiling, - true, + self.free_params_immediately, self.n_threads, self.weight_type, self.rng, @@ -323,7 +330,7 @@ impl Config { self.clip_on_cpu, self.control_net_cpu, self.vae_on_cpu, - self.flash_attenuation, + self.flash_attention, ) } diff --git a/src/api_modified.rs b/src/api_modified.rs new file mode 100644 index 0000000..ed03555 --- /dev/null +++ b/src/api_modified.rs @@ -0,0 +1,570 @@ +use std::ffi::c_char; +use std::ffi::c_void; +use std::ffi::CString; +use std::path::Path; +use std::path::PathBuf; +use std::ptr::null; +use std::slice; + +use derive_builder::Builder; +use diffusion_rs_sys::sd_image_t; +use image::ImageBuffer; +use image::Rgb; +use image::RgbImage; +use libc::free; +use thiserror::Error; + +use diffusion_rs_sys::free_sd_ctx; +use diffusion_rs_sys::new_sd_ctx; +use diffusion_rs_sys::sd_ctx_t; + +/// Specify the range function +pub use diffusion_rs_sys::rng_type_t as RngFunction; + +/// Sampling methods +pub use diffusion_rs_sys::sample_method_t as SampleMethod; + +/// Denoiser sigma schedule +pub use diffusion_rs_sys::schedule_t as Schedule; + +/// Weight type +pub use diffusion_rs_sys::sd_type_t as WeightType; + +#[non_exhaustive] +#[derive(Error, Debug)] +/// Error that can occurs while forwarding models +pub enum DiffusionError { + #[error("The underling stablediffusion.cpp function returned NULL")] + Forward, + #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] + StoreImages(usize, i32), + #[error("The underling upscaler model returned a NULL image")] + Upscaler, + #[error("Failed to convert image buffer to rust type")] + SDImagetoRustImage, + // #[error("Free Params Immediately is set to true, which means that the params are freed after forward. This means that the model can only be used once")] + // FreeParamsImmediately, +} + +#[repr(i32)] +#[non_exhaustive] +#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] +/// Ignore the lower X layers of CLIP network +pub enum ClipSkip { + /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x + #[default] + Unspecified = 0, + None = 1, + OneLayer = 2, +} + +#[derive(Debug, Clone, Default)] +struct CLibString(CString); + +impl CLibString { + fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } +} + +impl From<&str> for CLibString { + fn from(value: &str) -> Self { + Self(CString::new(value).unwrap()) + } +} + +impl From for CLibString { + fn from(value: String) -> Self { + Self(CString::new(value).unwrap()) + } +} + +#[derive(Debug, Clone, Default)] +struct CLibPath(CString); + +impl CLibPath { + fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } +} + +impl From for CLibPath { + fn from(value: PathBuf) -> Self { + Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + } +} + +impl From<&Path> for CLibPath { + fn from(value: &Path) -> Self { + Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + } +} + +#[derive(Builder, Debug, Clone)] +#[builder(setter(into), build_fn(validate = "Self::validate"))] +/// Config struct common to all diffusion methods +pub struct ModelConfig { + /// Path to full model + #[builder(default = "Default::default()")] + model: CLibPath, + + /// path to the clip-l text encoder + #[builder(default = "Default::default()")] + clip_l: CLibPath, + + /// path to the clip-g text encoder + #[builder(default = "Default::default()")] + clip_g: CLibPath, + + /// Path to the t5xxl text encoder + #[builder(default = "Default::default()")] + t5xxl: CLibPath, + + /// Path to the standalone diffusion model + #[builder(default = "Default::default()")] + diffusion_model: CLibPath, + + /// Path to vae + #[builder(default = "Default::default()")] + vae: CLibPath, + + /// Path to taesd. Using Tiny AutoEncoder for fast decoding (lower quality) + #[builder(default = "Default::default()")] + taesd: CLibPath, + + /// Path to control net model + #[builder(default = "Default::default()")] + control_net: CLibPath, + + /// Lora models directory + #[builder(default = "Default::default()", setter(custom))] + lora_model_dir: CLibPath, + + /// Path to embeddings directory + #[builder(default = "Default::default()")] + embeddings_dir: CLibPath, + + /// Path to PHOTOMAKER stacked id embeddings + #[builder(default = "Default::default()")] + stacked_id_embd_dir: CLibPath, + + //TODO: Add more info here for docs + /// vae decode only (default: false) + #[builder(default = "false")] + vae_decode_only: bool, + + /// Process vae in tiles to reduce memory usage (default: false) + #[builder(default = "false")] + vae_tiling: bool, + + /// free memory of params immediately after forward (default: false) + #[builder(default = "false")] + free_params_immediately: bool, + + /// Number of threads to use during computation (default: 0). + /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. + #[builder( + default = "std::thread::available_parallelism().map_or(1, |p| p.get() as i32)", + setter(custom) + )] + n_threads: i32, + + /// Weight type. If not specified, the default is the type of the weight file + #[builder(default = "WeightType::SD_TYPE_COUNT")] + weight_type: WeightType, + + /// RNG type (default: CUDA) + #[builder(default = "RngFunction::CUDA_RNG")] + rng_type: RngFunction, + + /// Denoiser sigma schedule (default: DEFAULT) + #[builder(default = "Schedule::DEFAULT")] + schedule: Schedule, + + /// keep clip on cpu (for low vram) (default: false) + #[builder(default = "false")] + keep_clip_on_cpu: bool, + + /// Keep controlnet in cpu (for low vram) (default: false) + #[builder(default = "false")] + keep_control_net_cpu: bool, + + /// Keep vae on cpu (for low vram) (default: false) + #[builder(default = "false")] + keep_vae_on_cpu: bool, + + /// Use flash attention in the diffusion model (for low vram). + /// Might lower quality, since it implies converting k and v to f16. + /// This might crash if it is not supported by the backend. + /// must have feature "flash_attention" enabled in the features. + /// (default: false) + #[builder(default = "false")] + flash_attention: bool, +} + +impl ModelConfigBuilder { + pub fn n_threads(&mut self, value: i32) -> &mut Self { + self.n_threads = if value > 0 { + Some(value) + } else { + Some(std::thread::available_parallelism().map_or(1, |p| p.get() as i32)) + }; + self + } + + fn validate(&self) -> Result<(), ModelConfigBuilderError> { + self.validate_model() + } + + fn validate_model(&self) -> Result<(), ModelConfigBuilderError> { + self.model + .as_ref() + .or(self.diffusion_model.as_ref()) + .map(|_| ()) + .ok_or(ModelConfigBuilderError::UninitializedField( + "Model OR DiffusionModel must be initialized", + )) + } +} + +#[derive(Builder, Debug, Clone)] +#[builder(setter(into), build_fn(validate = "Self::validate"))] +/// txt2img config +struct Txt2ImgConfig { + /// Prompt to generate image from + prompt: String, + + /// Suffix that needs to be added to prompt (e.g. lora model) + #[builder(default = "Default::default()", private)] + lora_prompt_suffix: Vec, + + /// The negative prompt (default: "") + #[builder(default = "\"\".into()")] + negative_prompt: CLibString, + + /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) + /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x + #[builder(default = "ClipSkip::Unspecified")] + clip_skip: ClipSkip, + + /// Unconditional guidance scale (default: 7.0) + #[builder(default = "7.0")] + cfg_scale: f32, + + /// Guidance (default: 3.5) + #[builder(default = "3.5")] + guidance: f32, + + /// Image height, in pixel space (default: 512) + #[builder(default = "512")] + height: i32, + + /// Image width, in pixel space (default: 512) + #[builder(default = "512")] + width: i32, + + /// Sampling-method (default: EULER_A) + #[builder(default = "SampleMethod::EULER_A")] + sample_method: SampleMethod, + + /// Number of sample steps (default: 20) + #[builder(default = "20")] + sample_steps: i32, + + /// RNG seed (default: 42, use random seed for < 0) + #[builder(default = "42")] + seed: i64, + + /// Number of images to generate (default: 1) + #[builder(default = "1")] + batch_count: i32, + + /// Strength to apply Control Net (default: 0.9) + /// 1.0 corresponds to full destruction of information in init + #[builder(default = "0.9")] + control_strength: f32, + + /// Strength for keeping input identity (default: 20%) + #[builder(default = "20.0")] + style_ratio: f32, + + /// Normalize PHOTOMAKER input id images + #[builder(default = "false")] + normalize_input: bool, + + /// Path to PHOTOMAKER input id images dir + #[builder(default = "Default::default()")] + input_id_images: CLibPath, + + /// Layers to skip for SLG steps: (default: [7,8,9]) + #[builder(default = "vec![7, 8, 9]")] + skip_layer: Vec, + + /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) + /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium + #[builder(default = "0.")] + slg_scale: f32, + + /// SLG enabling point: (default: 0.01) + #[builder(default = "0.01")] + skip_layer_start: f32, + + /// SLG disabling point: (default: 0.2) + #[builder(default = "0.2")] + skip_layer_end: f32, +} + +impl Txt2ImgConfigBuilder { + fn validate(&self) -> Result<(), Txt2ImgConfigBuilderError> { + self.validate_prompt() + } + + fn validate_prompt(&self) -> Result<(), Txt2ImgConfigBuilderError> { + self.prompt + .as_ref() + .map(|_| ()) + .ok_or(Txt2ImgConfigBuilderError::UninitializedField("Prompt")) + } + + pub fn add_lora_model(&mut self, filename: String, strength: f32) -> &mut Self { + self.lora_prompt_suffix + .get_or_insert_with(Vec::new) + .push(format!("")); + self + } +} + +struct ModelCtx { + /// The underlying C context + raw_ctx: *mut sd_ctx_t, + + /// We keep the config around in case we need to refer to it + pub model_config: ModelConfig, +} + +impl ModelCtx { + pub fn new(config: ModelConfig) -> Self { + let raw_ctx = unsafe { + new_sd_ctx( + config.model.as_ptr(), + config.clip_l.as_ptr(), + config.clip_g.as_ptr(), + config.t5xxl.as_ptr(), + config.diffusion_model.as_ptr(), + config.vae.as_ptr(), + config.taesd.as_ptr(), + config.control_net.as_ptr(), + config.lora_model_dir.as_ptr(), + config.embeddings_dir.as_ptr(), + config.stacked_id_embd_dir.as_ptr(), + config.vae_decode_only, + config.vae_tiling, + config.free_params_immediately, + config.n_threads, + config.weight_type, + config.rng_type, + config.schedule, + config.keep_clip_on_cpu, + config.keep_control_net_cpu, + config.keep_vae_on_cpu, + config.flash_attention, + ) + }; + + Self { + raw_ctx, + model_config: config, + } + } + + pub fn destroy(&mut self) { + unsafe { + if !self.raw_ctx.is_null() { + free_sd_ctx(self.raw_ctx); + self.raw_ctx = std::ptr::null_mut(); + } + } + } + + pub fn txt2img( + &mut self, + mut txt2img_config: Txt2ImgConfig, + ) -> Result, DiffusionError> { + // add loras to prompt as suffix + let prompt: CLibString = { + let mut prompt = txt2img_config.prompt.clone(); + for lora in txt2img_config.lora_prompt_suffix.iter() { + prompt.push_str(lora); + } + prompt.into() + }; + + let results: *mut sd_image_t = unsafe { + diffusion_rs_sys::txt2img( + self.raw_ctx, + prompt.as_ptr(), + txt2img_config.negative_prompt.as_ptr(), + txt2img_config.clip_skip as i32, + txt2img_config.cfg_scale, + txt2img_config.guidance, + txt2img_config.width, + txt2img_config.height, + txt2img_config.sample_method, + txt2img_config.sample_steps, + txt2img_config.seed, + txt2img_config.batch_count, + null(), + txt2img_config.control_strength, + txt2img_config.style_ratio, + txt2img_config.normalize_input, + txt2img_config.input_id_images.as_ptr(), + txt2img_config.skip_layer.as_mut_ptr(), + txt2img_config.skip_layer.len(), + txt2img_config.slg_scale, + txt2img_config.skip_layer_start, + txt2img_config.skip_layer_end, + ) + }; + + if results.is_null() { + return Err(DiffusionError::Forward); + } + + let result_images: Vec = unsafe { + let img_count = txt2img_config.batch_count as usize; + let images = slice::from_raw_parts(results, img_count); + let rgb_images: Result, DiffusionError> = images + .iter() + .map(|sd_img| { + let len = (sd_img.width * sd_img.height * sd_img.channel) as usize; + let raw_pixels = slice::from_raw_parts(sd_img.data, len); + let buffer = raw_pixels.to_vec(); + let buffer = ImageBuffer::, _>::from_raw( + sd_img.width as u32, + sd_img.height as u32, + buffer, + ); + Ok(match buffer { + Some(buffer) => RgbImage::from(buffer), + None => return Err(DiffusionError::SDImagetoRustImage), + }) + }) + .collect(); + match rgb_images { + Ok(images) => images, + Err(e) => return Err(e), + } + }; + + //Clean-up slice section + unsafe { + free(results as *mut c_void); + } + Ok(result_images) + } +} + +/// Automatic cleanup on drop +impl Drop for ModelCtx { + fn drop(&mut self) { + self.destroy(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_invalid_model_config() { + let config = ModelConfigBuilder::default().build(); + assert!(config.is_err(), "ModelConfig should fail without a model"); + } + + #[test] + fn test_valid_model_config() { + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./test.ckpt")) + .build(); + assert!(config.is_ok(), "ModelConfig should succeed with model path"); + } + + #[test] + fn test_invalid_txt2img_config() { + let config = Txt2ImgConfigBuilder::default().build(); + assert!(config.is_err(), "Txt2ImgConfig should fail without prompt"); + } + + #[test] + fn test_valid_txt2img_config() { + let config = Txt2ImgConfigBuilder::default() + .prompt("testing prompt") + .build(); + assert!(config.is_ok(), "Txt2ImgConfig should succeed with prompt"); + } + + #[test] + fn test_model_ctx_new_invalid() { + let config = ModelConfigBuilder::default().build(); + assert!(config.is_err()); + // Attempt creating ModelCtx with error + // This is hypothetical; we expect a builder error before this + } + + #[test] + fn test_txt2img_success() { + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) + .build() + .unwrap(); + let mut ctx = ModelCtx::new(config.clone()); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("test prompt") + .sample_steps(1) + .build() + .unwrap(); + let result = ctx.txt2img(txt2img_conf); + assert!(result.is_ok()); + } + + #[test] + fn test_txt2img_failure() { + // Build a context with invalid data to force failure + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) + .build() + .unwrap(); + let mut ctx = ModelCtx::new(config); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("test prompt") + .sample_steps(1) + .build() + .unwrap(); + // Hypothetical failure scenario + let result = ctx.txt2img(txt2img_conf); + // Expect an error if calling with invalid path + // This depends on your real implementation + assert!(result.is_err() || result.is_ok()); + } + + #[test] + fn test_multiple_images() { + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) + .build() + .unwrap(); + let mut ctx = ModelCtx::new(config); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("multi-image prompt") + .sample_steps(1) + .batch_count(3) + .build() + .unwrap(); + let result = ctx.txt2img(txt2img_conf); + assert!(result.is_ok()); + if let Ok(images) = result { + assert_eq!(images.len(), 3); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6441c03..4d194a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,4 @@ #![doc = include_str!("../README.md")] -/// Safer wrapper around stable-diffusion.cpp bindings -pub mod api; -/// Presets that automatically download models from -pub mod preset; - -/// Add additional resources to [preset::Preset] -pub mod modifier; -pub(crate) mod preset_builder; - -/// Util module -pub mod util; +/// Modified API +pub mod api_modified; diff --git a/src/modifier.rs b/src/modifier.rs deleted file mode 100644 index e3f08fa..0000000 --- a/src/modifier.rs +++ /dev/null @@ -1,168 +0,0 @@ -use hf_hub::api::sync::ApiError; - -use crate::{ - api::{ConfigBuilder, SampleMethod}, - util::download_file_hf_hub, -}; - -/// Add the upscaler -pub fn real_esrgan_x4plus_anime_6_b(mut builder: ConfigBuilder) -> Result { - let upscaler_path = download_file_hf_hub( - "ximso/RealESRGAN_x4plus_anime_6B", - "RealESRGAN_x4plus_anime_6B.pth", - )?; - builder.upscale_model(upscaler_path); - Ok(builder) -} - -/// Apply to avoid black images with xl models -pub fn sdxl_vae_fp16_fix(mut builder: ConfigBuilder) -> Result { - let vae_path = download_file_hf_hub("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors")?; - builder.vae(vae_path); - Ok(builder) -} - -/// Apply taesd autoencoder for faster decoding (SD v1/v2) -pub fn taesd(mut builder: ConfigBuilder) -> Result { - let taesd_path = - download_file_hf_hub("madebyollin/taesd", "diffusion_pytorch_model.safetensors")?; - builder.taesd(taesd_path); - Ok(builder) -} - -/// Apply taesd autoencoder for faster decoding (SDXL) -pub fn taesd_xl(mut builder: ConfigBuilder) -> Result { - let taesd_path = - download_file_hf_hub("madebyollin/taesdxl", "diffusion_pytorch_model.safetensors")?; - builder.taesd(taesd_path); - Ok(builder) -} - -/// Apply taesd autoencoder for faster decoding (SD v1/v2) -pub fn hybrid_taesd(mut builder: ConfigBuilder) -> Result { - let taesd_path = download_file_hf_hub( - "cqyan/hybrid-sd-tinyvae", - "diffusion_pytorch_model.safetensors", - )?; - builder.taesd(taesd_path); - Ok(builder) -} - -/// Apply taesd autoencoder for faster decoding (SDXL) -pub fn hybrid_taesd_xl(mut builder: ConfigBuilder) -> Result { - let taesd_path = download_file_hf_hub( - "cqyan/hybrid-sd-tinyvae-xl", - "diffusion_pytorch_model.safetensors", - )?; - builder.taesd(taesd_path); - Ok(builder) -} - -/// Apply to reduce inference steps for SD v1 between 2-8 -/// cfg_scale 1. 4 steps. -pub fn lcm_lora_sd_1_5(mut builder: ConfigBuilder) -> Result { - let lora_path = download_file_hf_hub( - "latent-consistency/lcm-lora-sdv1-5", - "pytorch_lora_weights.safetensors", - )?; - builder.lora_model(&lora_path).cfg_scale(1.).steps(4); - Ok(builder) -} - -/// Apply to reduce inference steps for SD v1 between 2-8 (default 8) -/// Enabled [api::SampleMethod::LCM]. cfg_scale 2. 8 steps. -pub fn lcm_lora_sdxl_base_1_0(mut builder: ConfigBuilder) -> Result { - let lora_path = download_file_hf_hub( - "latent-consistency/lcm-lora-sdxl", - "pytorch_lora_weights.safetensors", - )?; - builder - .lora_model(&lora_path) - .cfg_scale(2.) - .steps(8) - .sampling_method(SampleMethod::LCM); - Ok(builder) -} - -/// Apply Fp8 t5xxl text encoder to reduce memory usage -pub fn t5xxl_fp8_flux_1(mut builder: ConfigBuilder) -> Result { - let t5xxl_path = download_file_hf_hub( - "comfyanonymous/flux_text_encoders", - "t5xxl_fp8_e4m3fn.safetensors", - )?; - - builder.t5xxl(t5xxl_path); - Ok(builder) -} - -/// Apply -/// Default for flux_1_dev/schnell -pub fn t5xxl_fp16_flux_1(mut builder: ConfigBuilder) -> Result { - let t5xxl_path = download_file_hf_hub( - "comfyanonymous/flux_text_encoders", - "t5xxl_fp16.safetensors", - )?; - - builder.t5xxl(t5xxl_path); - Ok(builder) -} - -#[cfg(test)] -mod tests { - use crate::{ - api::txt2img, - preset::{Modifier, Preset, PresetBuilder}, - }; - - use super::{ - hybrid_taesd, hybrid_taesd_xl, lcm_lora_sd_1_5, lcm_lora_sdxl_base_1_0, taesd, taesd_xl, - }; - - static PROMPT: &str = "a lovely duck drinking water from a bottle"; - - fn run(preset: Preset, m: Modifier) { - let config = PresetBuilder::default() - .preset(preset) - .prompt(PROMPT) - .with_modifier(m) - .build() - .unwrap(); - txt2img(config).unwrap(); - } - - #[ignore] - #[test] - fn test_taesd() { - run(Preset::StableDiffusion1_5, taesd); - } - - #[ignore] - #[test] - fn test_taesd_xl() { - run(Preset::SDXLTurbo1_0Fp16, taesd_xl); - } - - #[ignore] - #[test] - fn test_hybrid_taesd() { - run(Preset::StableDiffusion1_5, hybrid_taesd); - } - - #[ignore] - #[test] - fn test_hybrid_taesd_xl() { - run(Preset::SDXLTurbo1_0Fp16, hybrid_taesd_xl); - } - - #[ignore] - #[test] - fn test_lcm_lora_sd_1_5() { - run(Preset::StableDiffusion1_5, lcm_lora_sd_1_5); - } - - #[ignore] - #[test] - fn test_lcm_lora_sdxl_base_1_0() { - run(Preset::SDXLBase1_0, lcm_lora_sdxl_base_1_0); - } -} diff --git a/src/preset.rs b/src/preset.rs deleted file mode 100644 index cc76129..0000000 --- a/src/preset.rs +++ /dev/null @@ -1,234 +0,0 @@ -use derive_builder::Builder; -use hf_hub::api::sync::ApiError; - -use crate::{ - api::{self, Config, ConfigBuilder, ConfigBuilderError}, - preset_builder::{ - flux_1_dev, flux_1_mini, flux_1_schnell, juggernaut_xl_11, sd_turbo, sdxl_base_1_0, - sdxl_turbo_1_0_fp16, stable_diffusion_1_4, stable_diffusion_1_5, stable_diffusion_2_1, - stable_diffusion_3_5_large_fp16, stable_diffusion_3_5_large_turbo_fp16, - stable_diffusion_3_5_medium_fp16, stable_diffusion_3_medium_fp16, - }, -}; - -#[non_exhaustive] -#[derive(Debug, Clone, Copy)] -/// Models ready to use -pub enum Preset { - StableDiffusion1_4, - StableDiffusion1_5, - /// model. - /// Vae-tiling enabled. 768x768. - StableDiffusion2_1, - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. 30 steps. - StableDiffusion3MediumFp16, - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. cfg_scale 4.5. 40 steps. - StableDiffusion3_5MediumFp16, - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. cfg_scale 4.5. 28 steps. - StableDiffusion3_5LargeFp16, - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. cfg_scale 0. 4 steps. - StableDiffusion3_5LargeTurboFp16, - SDXLBase1_0, - /// cfg_scale 1. guidance 0. 4 steps - SDTurbo, - /// cfg_scale 1. guidance 0. 4 steps - SDXLTurbo1_0Fp16, - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. 28 steps. - Flux1Dev(api::WeightType), - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::EULER]. 4 steps. - Flux1Schnell(api::WeightType), - /// A 3.2B param rectified flow transformer distilled from FLUX.1 [dev] https://huggingface.co/TencentARC/flux-mini - /// Vae-tiling enabled. 512x512. Enabled [api::SampleMethod::EULER]. 28 steps. - Flux1Mini, - /// Requires access rights to providing a token via [crate::util::set_hf_token] - /// Vae-tiling enabled. 1024x1024. Enabled [api::SampleMethod::DPM2]. guidance 6. 20 steps - JuggernautXL11, -} - -impl Preset { - fn try_config_builder(self) -> Result { - match self { - Preset::StableDiffusion1_4 => stable_diffusion_1_4(), - Preset::StableDiffusion1_5 => stable_diffusion_1_5(), - Preset::StableDiffusion2_1 => stable_diffusion_2_1(), - Preset::StableDiffusion3MediumFp16 => stable_diffusion_3_medium_fp16(), - Preset::SDXLBase1_0 => sdxl_base_1_0(), - Preset::Flux1Dev(sd_type_t) => flux_1_dev(sd_type_t), - Preset::Flux1Schnell(sd_type_t) => flux_1_schnell(sd_type_t), - Preset::SDTurbo => sd_turbo(), - Preset::SDXLTurbo1_0Fp16 => sdxl_turbo_1_0_fp16(), - Preset::StableDiffusion3_5LargeFp16 => stable_diffusion_3_5_large_fp16(), - Preset::StableDiffusion3_5MediumFp16 => stable_diffusion_3_5_medium_fp16(), - Preset::StableDiffusion3_5LargeTurboFp16 => stable_diffusion_3_5_large_turbo_fp16(), - Preset::JuggernautXL11 => juggernaut_xl_11(), - Preset::Flux1Mini => flux_1_mini(), - } - } -} - -/// Helper functions that modifies the [ConfigBuilder] See [crate::modifier] -pub type Modifier = fn(ConfigBuilder) -> Result; - -#[derive(Debug, Clone, Builder)] -#[builder( - name = "PresetBuilder", - setter(into), - build_fn(name = "internal_build", private, error = "ConfigBuilderError") -)] -/// Helper struct for [ConfigBuilder] -pub struct PresetConfig { - prompt: String, - preset: Preset, - #[builder(private, default = "Vec::new()")] - modifiers: Vec Result>, -} - -impl PresetBuilder { - /// Add modifier that will apply in sequence - pub fn with_modifier(&mut self, f: Modifier) -> &mut Self { - if self.modifiers.is_none() { - self.modifiers = Some(Vec::new()); - } - self.modifiers.as_mut().unwrap().push(f); - self - } - - pub fn build(&mut self) -> Result { - let preset = self.internal_build()?; - let config: ConfigBuilder = preset - .try_into() - .map_err(|err: ApiError| ConfigBuilderError::ValidationError(err.to_string()))?; - config.build() - } -} - -impl TryFrom for ConfigBuilder { - type Error = ApiError; - - fn try_from(value: PresetConfig) -> Result { - let mut config_builder = value.preset.try_config_builder()?; - for m in value.modifiers { - config_builder = m(config_builder)?; - } - config_builder.prompt(value.prompt); - Ok(config_builder) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - api::{self, txt2img}, - util::set_hf_token, - }; - - use super::{Preset, PresetBuilder}; - static PROMPT: &str = "a lovely duck drinking water from a bottle"; - - fn run(preset: Preset) { - let config = PresetBuilder::default() - .preset(preset) - .prompt(PROMPT) - .build() - .unwrap(); - txt2img(config).unwrap(); - } - - #[ignore] - #[test] - fn test_stable_diffusion_1_4() { - run(Preset::StableDiffusion1_4); - } - - #[ignore] - #[test] - fn test_stable_diffusion_1_5() { - run(Preset::StableDiffusion1_5); - } - - #[ignore] - #[test] - fn test_stable_diffusion_2_1() { - run(Preset::StableDiffusion2_1); - } - - #[ignore] - #[test] - fn test_stable_diffusion_3_medium_fp16() { - set_hf_token(include_str!("../token.txt")); - run(Preset::StableDiffusion3MediumFp16); - } - - #[ignore] - #[test] - fn test_sdxl_base_1_0() { - run(Preset::SDXLBase1_0); - } - - #[ignore] - #[test] - fn test_flux_1_dev() { - set_hf_token(include_str!("../token.txt")); - run(Preset::Flux1Dev(api::WeightType::SD_TYPE_Q2_K)); - } - - #[ignore] - #[test] - fn test_flux_1_schnell() { - set_hf_token(include_str!("../token.txt")); - run(Preset::Flux1Schnell(api::WeightType::SD_TYPE_Q2_K)); - } - - #[ignore] - #[test] - fn test_sd_turbo() { - run(Preset::SDTurbo); - } - - #[ignore] - #[test] - fn test_sdxl_turbo_1_0_fp16() { - run(Preset::SDXLTurbo1_0Fp16); - } - - #[ignore] - #[test] - fn test_stable_diffusion_3_5_medium_fp16() { - set_hf_token(include_str!("../token.txt")); - run(Preset::StableDiffusion3_5MediumFp16); - } - - #[ignore] - #[test] - fn test_stable_diffusion_3_5_large_fp16() { - set_hf_token(include_str!("../token.txt")); - run(Preset::StableDiffusion3_5LargeFp16); - } - - #[ignore] - #[test] - fn test_stable_diffusion_3_5_large_turbo_fp16() { - set_hf_token(include_str!("../token.txt")); - run(Preset::StableDiffusion3_5LargeTurboFp16); - } - - #[ignore] - #[test] - fn test_juggernaut_xl_11() { - set_hf_token(include_str!("../token.txt")); - run(Preset::JuggernautXL11); - } - - #[ignore] - #[test] - fn test_flux_1_mini() { - set_hf_token(include_str!("../token.txt")); - run(Preset::Flux1Mini); - } -} diff --git a/src/preset_builder.rs b/src/preset_builder.rs deleted file mode 100644 index a2c6c35..0000000 --- a/src/preset_builder.rs +++ /dev/null @@ -1,255 +0,0 @@ -use std::path::PathBuf; - -use crate::{ - api::{self, SampleMethod}, - modifier::{sdxl_vae_fp16_fix, t5xxl_fp16_flux_1, t5xxl_fp8_flux_1}, -}; -use hf_hub::api::sync::ApiError; - -use crate::{api::ConfigBuilder, util::download_file_hf_hub}; - -pub fn stable_diffusion_1_4() -> Result { - let model_path = - download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt")?; - - let mut config = ConfigBuilder::default(); - - config.model(model_path); - - Ok(config) -} - -pub fn stable_diffusion_1_5() -> Result { - let model_path = download_file_hf_hub( - "stablediffusiontutorials/stable-diffusion-v1.5", - "v1-5-pruned-emaonly.safetensors", - )?; - - let mut config = ConfigBuilder::default(); - - config.model(model_path); - - Ok(config) -} - -pub fn stable_diffusion_2_1() -> Result { - let model_path = download_file_hf_hub( - "stabilityai/stable-diffusion-2-1", - "v2-1_768-nonema-pruned.safetensors", - )?; - - let mut config = ConfigBuilder::default(); - - config - .model(model_path) - .vae_tiling(true) - .steps(25) - .height(768) - .width(768); - - Ok(config) -} - -pub fn stable_diffusion_3_medium_fp16() -> Result { - let model_path = download_file_hf_hub( - "stabilityai/stable-diffusion-3-medium", - "sd3_medium_incl_clips_t5xxlfp16.safetensors", - )?; - - let mut config = ConfigBuilder::default(); - - config - .model(model_path) - .vae_tiling(true) - .cfg_scale(4.5) - .sampling_method(SampleMethod::EULER) - .steps(30) - .height(1024) - .width(1024); - - Ok(config) -} - -pub fn sdxl_base_1_0() -> Result { - let model_path = download_file_hf_hub( - "stabilityai/stable-diffusion-xl-base-1.0", - "sd_xl_base_1.0.safetensors", - )?; - - let mut config = ConfigBuilder::default(); - - config - .model(model_path) - .vae_tiling(true) - .height(1024) - .width(1024); - sdxl_vae_fp16_fix(config) -} - -pub fn flux_1_dev(sd_type: api::WeightType) -> Result { - let model_path = flux_1_model_weight("dev", sd_type)?; - let mut builder = flux_1("dev", 28)?; - - builder.diffusion_model(model_path); - t5xxl_fp16_flux_1(builder) -} - -pub fn flux_1_schnell(sd_type: api::WeightType) -> Result { - let model_path = flux_1_model_weight("schnell", sd_type)?; - let mut builder = flux_1("schnell", 4)?; - - builder.diffusion_model(model_path); - t5xxl_fp16_flux_1(builder) -} - -fn flux_1_model_weight(model: &str, sd_type: api::WeightType) -> Result { - check_flux_type(sd_type); - let weight_type = flux_type_to_model(sd_type); - download_file_hf_hub( - format!("leejet/FLUX.1-{model}-gguf").as_str(), - format!("flux1-{model}-{}.gguf", weight_type).as_str(), - ) -} - -fn flux_1(vae_model: &str, steps: i32) -> Result { - let mut config = ConfigBuilder::default(); - let vae_path = download_file_hf_hub( - format!("black-forest-labs/FLUX.1-{vae_model}").as_str(), - "ae.safetensors", - )?; - let clip_l_path = - download_file_hf_hub("comfyanonymous/flux_text_encoders", "clip_l.safetensors")?; - - config - .vae(vae_path) - .clip_l(clip_l_path) - .vae_tiling(true) - .cfg_scale(1.) - .sampling_method(SampleMethod::EULER) - .steps(steps) - .height(1024) - .width(1024); - - Ok(config) -} - -fn check_flux_type(sd_type: api::WeightType) { - assert!( - sd_type == api::WeightType::SD_TYPE_Q2_K - || sd_type == api::WeightType::SD_TYPE_Q3_K - || sd_type == api::WeightType::SD_TYPE_Q4_0 - || sd_type == api::WeightType::SD_TYPE_Q4_K - || sd_type == api::WeightType::SD_TYPE_Q8_0 - ); -} - -fn flux_type_to_model(sd_type: api::WeightType) -> &'static str { - match sd_type { - api::WeightType::SD_TYPE_Q3_K => "q3_k", - api::WeightType::SD_TYPE_Q2_K => "q2_k", - api::WeightType::SD_TYPE_Q4_0 => "q4_0", - api::WeightType::SD_TYPE_Q4_K => "q4_k", - api::WeightType::SD_TYPE_Q8_0 => "q8_0", - _ => "not_supported", - } -} - -pub fn sd_turbo() -> Result { - let model_path = download_file_hf_hub("stabilityai/sd-turbo", "sd_turbo.safetensors")?; - - let mut config = ConfigBuilder::default(); - - config.model(model_path).guidance(0.).cfg_scale(1.).steps(4); - - Ok(config) -} - -pub fn sdxl_turbo_1_0_fp16() -> Result { - let model_path = - download_file_hf_hub("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors")?; - - let mut config = ConfigBuilder::default(); - - config.model(model_path).guidance(0.).cfg_scale(1.).steps(4); - sdxl_vae_fp16_fix(config) -} - -pub fn stable_diffusion_3_5_large_fp16() -> Result { - stable_diffusion_3_5("large", "large", 28, 4.5) -} - -pub fn stable_diffusion_3_5_large_turbo_fp16() -> Result { - stable_diffusion_3_5("large-turbo", "large_turbo", 4, 0.) -} - -pub fn stable_diffusion_3_5_medium_fp16() -> Result { - stable_diffusion_3_5("medium", "medium", 40, 4.5) -} - -pub fn stable_diffusion_3_5( - model: &str, - file_model: &str, - steps: i32, - cfg_scale: f32, -) -> Result { - let model_path = download_file_hf_hub( - format!("stabilityai/stable-diffusion-3.5-{model}").as_str(), - format!("sd3.5_{file_model}.safetensors").as_str(), - )?; - - let clip_g_path = download_file_hf_hub( - "Comfy-Org/stable-diffusion-3.5-fp8", - "text_encoders/clip_g.safetensors", - )?; - let clip_l_path = download_file_hf_hub( - "Comfy-Org/stable-diffusion-3.5-fp8", - "text_encoders/clip_l.safetensors", - )?; - let t5xxl_path = download_file_hf_hub( - "Comfy-Org/stable-diffusion-3.5-fp8", - "text_encoders/t5xxl_fp16.safetensors", - )?; - - let mut config = ConfigBuilder::default(); - - config - .diffusion_model(model_path) - .clip_l(clip_l_path) - .clip_g(clip_g_path) - .t5xxl(t5xxl_path) - .vae_tiling(true) - .cfg_scale(cfg_scale) - .sampling_method(SampleMethod::EULER) - .steps(steps) - .height(1024) - .width(1024); - - Ok(config) -} - -pub fn juggernaut_xl_11() -> Result { - let model_path = download_file_hf_hub( - "RunDiffusion/Juggernaut-XI-v11", - "Juggernaut-XI-byRunDiffusion.safetensors", - )?; - - let mut config = ConfigBuilder::default(); - - config - .model(model_path) - .vae_tiling(true) - .sampling_method(SampleMethod::DPM2) - .steps(20) - .guidance(6.) - .height(1024) - .width(1024); - - Ok(config) -} - -pub fn flux_1_mini() -> Result { - let model_path = download_file_hf_hub("TencentARC/flux-mini", "flux-mini.safetensors")?; - let mut builder = flux_1("dev", 28)?; - builder.diffusion_model(model_path).width(512).height(512); - t5xxl_fp8_flux_1(builder) -} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index acdd7a7..0000000 --- a/src/util.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::{ - path::PathBuf, - sync::{OnceLock, RwLock}, -}; - -use hf_hub::api::sync::{ApiBuilder, ApiError}; - -static TOKEN: OnceLock> = OnceLock::new(); - -/// Set the huggingface hub token to access "protected" models. See -pub fn set_hf_token(token: &str) { - let guard = TOKEN.get_or_init(|| RwLock::new(Default::default())); - let mut data = guard.write().unwrap(); - *data = token.to_owned(); -} - -/// Download file from huggingface hub -pub fn download_file_hf_hub(repo: &str, file: &str) -> Result { - let token = TOKEN.get().map(|token| token.read().unwrap().to_owned()); - let repo = ApiBuilder::new() - .with_token(token) - .build()? - .model(repo.to_string()); - repo.get(file) -} From 7f3d6c409d1eacb500d36f71012cbdd5ce943014 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Thu, 30 Jan 2025 18:20:20 -0500 Subject: [PATCH 02/33] separate files --- .gitignore | 1 + src/api.rs | 757 ++++++++++++++---------------------------- src/api_modified.rs | 570 ------------------------------- src/img2img_config.rs | 0 src/lib.rs | 6 +- src/model_config.rs | 139 ++++++++ src/old_api.rs | 574 ++++++++++++++++++++++++++++++++ src/txt2img_config.rs | 112 +++++++ src/utils.rs | 76 +++++ 9 files changed, 1153 insertions(+), 1082 deletions(-) delete mode 100644 src/api_modified.rs create mode 100644 src/img2img_config.rs create mode 100644 src/model_config.rs create mode 100644 src/old_api.rs create mode 100644 src/txt2img_config.rs create mode 100644 src/utils.rs diff --git a/.gitignore b/.gitignore index 7d3ee6e..ad33e30 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ bin/act *.png .idea/ +models/ \ No newline at end of file diff --git a/src/api.rs b/src/api.rs index d6f4fbe..c7b134b 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,574 +1,311 @@ -use std::ffi::c_char; use std::ffi::c_void; -use std::ffi::CString; -use std::path::Path; -use std::path::PathBuf; use std::ptr::null; use std::slice; -use derive_builder::Builder; -use diffusion_rs_sys::free_upscaler_ctx; -use diffusion_rs_sys::new_upscaler_ctx; use diffusion_rs_sys::sd_image_t; -use diffusion_rs_sys::upscaler_ctx_t; +use image::ImageBuffer; +use image::Rgb; +use image::RgbImage; use libc::free; -use thiserror::Error; use diffusion_rs_sys::free_sd_ctx; use diffusion_rs_sys::new_sd_ctx; use diffusion_rs_sys::sd_ctx_t; -use diffusion_rs_sys::stbi_write_png_custom; -/// Specify the range function -pub use diffusion_rs_sys::rng_type_t as RngFunction; +use crate::model_config::ModelConfig; +use crate::txt2img_config::Txt2ImgConfig; +use crate::utils::CLibString; +use crate::utils::DiffusionError; +struct ModelCtx { + /// The underlying C context + raw_ctx: *mut sd_ctx_t, -/// Sampling methods -pub use diffusion_rs_sys::sample_method_t as SampleMethod; - -/// Denoiser sigma schedule -pub use diffusion_rs_sys::schedule_t as Schedule; - -/// Weight type -pub use diffusion_rs_sys::sd_type_t as WeightType; - -#[non_exhaustive] -#[derive(Error, Debug)] -/// Error that can occurs while forwarding models -pub enum DiffusionError { - #[error("The underling stablediffusion.cpp function returned NULL")] - Forward, - #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] - StoreImages(usize, i32), - #[error("The underling upsclaer model returned a NULL image")] - Upscaler, -} - -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] -/// Ignore the lower X layers of CLIP network -pub enum ClipSkip { - /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x - #[default] - Unspecified = 0, - None = 1, - OneLayer = 2, + /// We keep the config around in case we need to refer to it + pub model_config: ModelConfig, } -#[derive(Builder, Debug, Clone)] -#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] - -/// Config struct common to all diffusion methods -pub struct Config { - /// Number of threads to use during computation (default: 0). - /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. - #[builder(default = "num_cpus::get_physical() as i32", setter(custom))] - n_threads: i32, - - /// Path to full model - #[builder(default = "Default::default()")] - model: CLibPath, - - /// Path to the standalone diffusion model - #[builder(default = "Default::default()")] - diffusion_model: CLibPath, - - /// path to the clip-l text encoder - #[builder(default = "Default::default()")] - clip_l: CLibPath, - - /// path to the clip-g text encoder - #[builder(default = "Default::default()")] - clip_g: CLibPath, - - /// Path to the t5xxl text encoder - #[builder(default = "Default::default()")] - t5xxl: CLibPath, - - /// Path to vae - #[builder(default = "Default::default()")] - vae: CLibPath, - - /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) - #[builder(default = "Default::default()")] - taesd: CLibPath, - - /// Path to control net model - #[builder(default = "Default::default()")] - control_net: CLibPath, - - /// Path to embeddings - #[builder(default = "Default::default()")] - embeddings: CLibPath, - - /// Path to PHOTOMAKER stacked id embeddings - #[builder(default = "Default::default()")] - stacked_id_embd: CLibPath, - - /// Path to PHOTOMAKER input id images dir - #[builder(default = "Default::default()")] - input_id_images: CLibPath, - - /// Normalize PHOTOMAKER input id images - #[builder(default = "false")] - normalize_input: bool, - - /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now - #[builder(default = "Default::default()")] - upscale_model: Option, - - /// Run the ESRGAN upscaler this many times (default 1) - #[builder(default = "0")] - upscale_repeats: i32, - - /// Weight type. If not specified, the default is the type of the weight file - #[builder(default = "WeightType::SD_TYPE_COUNT")] - weight_type: WeightType, - - /// Lora model directory - #[builder(default = "Default::default()", setter(custom))] - lora_model: CLibPath, - - /// Path to the input image, required by img2img - #[builder(default = "Default::default()")] - init_img: CLibPath, - - /// Path to image condition, control net - #[builder(default = "Default::default()")] - control_image: CLibPath, - - /// Path to write result image to (default: ./output.png) - #[builder(default = "PathBuf::from(\"./output.png\")")] - output: PathBuf, - - /// The prompt to render - prompt: String, - - /// The negative prompt (default: "") - #[builder(default = "\"\".into()")] - negative_prompt: CLibString, - - /// Unconditional guidance scale (default: 7.0) - #[builder(default = "7.0")] - cfg_scale: f32, - - /// Guidance (default: 3.5) - #[builder(default = "3.5")] - guidance: f32, - - /// Strength for noising/unnoising (default: 0.75) - #[builder(default = "0.75")] - strength: f32, - - /// Strength for keeping input identity (default: 20%) - #[builder(default = "20.0")] - style_ratio: f32, - - /// Strength to apply Control Net (default: 0.9) - /// 1.0 corresponds to full destruction of information in init - #[builder(default = "0.9")] - control_strength: f32, - - /// Image height, in pixel space (default: 512) - #[builder(default = "512")] - height: i32, - - /// Image width, in pixel space (default: 512) - #[builder(default = "512")] - width: i32, - - /// Sampling-method (default: EULER_A) - #[builder(default = "SampleMethod::EULER_A")] - sampling_method: SampleMethod, - - /// Number of sample steps (default: 20) - #[builder(default = "20")] - steps: i32, - - /// RNG (default: CUDA) - #[builder(default = "RngFunction::CUDA_RNG")] - rng: RngFunction, - - /// RNG seed (default: 42, use random seed for < 0) - #[builder(default = "42")] - seed: i64, - - /// Number of images to generate (default: 1) - #[builder(default = "1")] - batch_count: i32, - - /// Denoiser sigma schedule (default: DEFAULT) - #[builder(default = "Schedule::DEFAULT")] - schedule: Schedule, - - /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) - /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x - #[builder(default = "ClipSkip::Unspecified")] - clip_skip: ClipSkip, - - /// Process vae in tiles to reduce memory usage (default: false) - #[builder(default = "false")] - vae_tiling: bool, - - /// free memory of params immediately after forward (default: true) - #[builder(default = "true")] - free_params_immediately: bool, - - /// Keep vae in cpu (for low vram) (default: false) - #[builder(default = "false")] - vae_on_cpu: bool, - - /// keep clip in cpu (for low vram) (default: false) - #[builder(default = "false")] - clip_on_cpu: bool, - - /// Keep controlnet in cpu (for low vram) (default: false) - #[builder(default = "false")] - control_net_cpu: bool, - - /// Apply canny preprocessor (edge detection) (default: false) - #[builder(default = "false")] - canny: bool, - - /// Suffix that needs to be added to prompt (e.g. lora model) - #[builder(default = "None", private)] - prompt_suffix: Option, - - /// Use flash attention in the diffusion model (for low vram). - /// Might lower quality, since it implies converting k and v to f16. - /// This might crash if it is not supported by the backend. - #[builder(default = "false")] - flash_attention: bool, - - /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) - /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium - #[builder(default = "0.")] - slg_scale: f32, - - /// Layers to skip for SLG steps: (default: [7,8,9]) - #[builder(default = "vec![7, 8, 9]")] - skip_layer: Vec, - - /// SLG enabling point: (default: 0.01) - #[builder(default = "0.01")] - skip_layer_start: f32, +impl ModelCtx { + pub fn new(config: ModelConfig) -> Self { + let raw_ctx = unsafe { + new_sd_ctx( + config.model.as_ptr(), + config.clip_l.as_ptr(), + config.clip_g.as_ptr(), + config.t5xxl.as_ptr(), + config.diffusion_model.as_ptr(), + config.vae.as_ptr(), + config.taesd.as_ptr(), + config.control_net.as_ptr(), + config.lora_model_dir.as_ptr(), + config.embeddings_dir.as_ptr(), + config.stacked_id_embd_dir.as_ptr(), + config.vae_decode_only, + config.vae_tiling, + config.free_params_immediately, + config.n_threads, + config.weight_type, + config.rng_type, + config.schedule, + config.keep_clip_on_cpu, + config.keep_control_net_cpu, + config.keep_vae_on_cpu, + config.flash_attention, + ) + }; - /// SLG disabling point: (default: 0.2) - #[builder(default = "0.2")] - skip_layer_end: f32, -} + Self { + raw_ctx, + model_config: config, + } + } -impl ConfigBuilder { - /// add Lora model and clip strength to the prompt suffix - /// e.g. "" - pub fn lora_model(&mut self, lora_model: &Path, clip_strength: f32) -> &mut Self { - let folder = lora_model.parent().unwrap(); - let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned(); - self.prompt_suffix(format!("")); - self.lora_model = Some(folder.into()); - self + pub fn destroy(&mut self) { + unsafe { + if !self.raw_ctx.is_null() { + free_sd_ctx(self.raw_ctx); + self.raw_ctx = std::ptr::null_mut(); + } + } } - pub fn n_threads(&mut self, value: i32) -> &mut Self { - self.n_threads = if value > 0 { - Some(value) - } else { - Some(num_cpus::get_physical() as i32) + pub fn txt2img( + &mut self, + mut txt2img_config: Txt2ImgConfig, + ) -> Result, DiffusionError> { + // add loras to prompt as suffix + let prompt: CLibString = { + let mut prompt = txt2img_config.prompt.clone(); + for lora in txt2img_config.lora_prompt_suffix.iter() { + prompt.push_str(lora); + } + prompt.into() }; - self - } - fn validate(&self) -> Result<(), ConfigBuilderError> { - self.validate_model()?; - self.validate_output_dir() - } + //print prompt for debugging + println!( + "Prompt: {:?}", + prompt.0.to_str().expect("Couldn't get string") + ); - fn validate_model(&self) -> Result<(), ConfigBuilderError> { - self.model - .as_ref() - .or(self.diffusion_model.as_ref()) - .map(|_| ()) - .ok_or(ConfigBuilderError::UninitializedField( - "Model OR DiffusionModel must be valorized", - )) - } + //controlnet support - fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> { - let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir()); - let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1); - if is_dir == multiple_items { - Ok(()) - } else { - Err(ConfigBuilderError::ValidationError( - "When batch_count > 0, ouput should point to folder and viceversa".to_owned(), - )) + let results: *mut sd_image_t = unsafe { + diffusion_rs_sys::txt2img( + self.raw_ctx, + prompt.as_ptr(), + txt2img_config.negative_prompt.as_ptr(), + txt2img_config.clip_skip as i32, + txt2img_config.cfg_scale, + txt2img_config.guidance, + txt2img_config.width, + txt2img_config.height, + txt2img_config.sample_method, + txt2img_config.sample_steps, + txt2img_config.seed, + txt2img_config.batch_count, + null(), + txt2img_config.control_strength, + txt2img_config.style_ratio, + txt2img_config.normalize_input, + txt2img_config.input_id_images.as_ptr(), + txt2img_config.skip_layer.as_mut_ptr(), + txt2img_config.skip_layer.len(), + txt2img_config.slg_scale, + txt2img_config.skip_layer_start, + txt2img_config.skip_layer_end, + ) + }; + + if results.is_null() { + return Err(DiffusionError::Forward); } - } -} -impl Config { - unsafe fn build_sd_ctx(&self, vae_decode_only: bool) -> *mut sd_ctx_t { - new_sd_ctx( - self.model.as_ptr(), - self.clip_l.as_ptr(), - self.clip_g.as_ptr(), - self.t5xxl.as_ptr(), - self.diffusion_model.as_ptr(), - self.vae.as_ptr(), - self.taesd.as_ptr(), - self.control_net.as_ptr(), - self.lora_model.as_ptr(), - self.embeddings.as_ptr(), - self.stacked_id_embd.as_ptr(), - vae_decode_only, - self.vae_tiling, - self.free_params_immediately, - self.n_threads, - self.weight_type, - self.rng, - self.schedule, - self.clip_on_cpu, - self.control_net_cpu, - self.vae_on_cpu, - self.flash_attention, - ) - } + let result_images: Vec = unsafe { + let img_count = txt2img_config.batch_count as usize; + let images = slice::from_raw_parts(results, img_count); + let rgb_images: Result, DiffusionError> = images + .iter() + .map(|sd_img| { + let len = (sd_img.width * sd_img.height * sd_img.channel) as usize; + let raw_pixels = slice::from_raw_parts(sd_img.data, len); + let buffer = raw_pixels.to_vec(); + let buffer = ImageBuffer::, _>::from_raw( + sd_img.width as u32, + sd_img.height as u32, + buffer, + ); + Ok(match buffer { + Some(buffer) => RgbImage::from(buffer), + None => return Err(DiffusionError::SDImagetoRustImage), + }) + }) + .collect(); + match rgb_images { + Ok(images) => images, + Err(e) => return Err(e), + } + }; - unsafe fn upscaler_ctx(&self) -> Option<*mut upscaler_ctx_t> { - if self.upscale_model.is_none() || self.upscale_repeats == 0 { - None - } else { - let upscaler = new_upscaler_ctx( - self.upscale_model.as_ref().unwrap().as_ptr(), - self.n_threads, - ); - Some(upscaler) + //Clean-up slice section + unsafe { + free(results as *mut c_void); } + Ok(result_images) } } -#[derive(Debug, Clone, Default)] -struct CLibString(CString); - -impl CLibString { - fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() +/// Automatic cleanup on drop +impl Drop for ModelCtx { + fn drop(&mut self) { + self.destroy(); } } -impl From<&str> for CLibString { - fn from(value: &str) -> Self { - Self(CString::new(value).unwrap()) - } -} +#[cfg(test)] +mod tests { + /// Sampling methods + pub use diffusion_rs_sys::sample_method_t as SampleMethod; + /// Denoiser sigma schedule + pub use diffusion_rs_sys::schedule_t as Schedule; -impl From for CLibString { - fn from(value: String) -> Self { - Self(CString::new(value).unwrap()) - } -} + use crate::utils::ClipSkip; + use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; -#[derive(Debug, Clone, Default)] -struct CLibPath(CString); + use super::*; + use std::path::PathBuf; -impl CLibPath { - fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() + #[test] + fn test_invalid_model_config() { + let config = ModelConfigBuilder::default().build(); + assert!(config.is_err(), "ModelConfig should fail without a model"); } -} -impl From for CLibPath { - fn from(value: PathBuf) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + #[test] + fn test_valid_model_config() { + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./test.ckpt")) + .build(); + assert!(config.is_ok(), "ModelConfig should succeed with model path"); } -} -impl From<&Path> for CLibPath { - fn from(value: &Path) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + #[test] + fn test_invalid_txt2img_config() { + let config = Txt2ImgConfigBuilder::default().build(); + assert!(config.is_err(), "Txt2ImgConfig should fail without prompt"); } -} -fn output_files(path: PathBuf, batch_size: i32) -> Vec { - if batch_size == 1 { - vec![path.into()] - } else { - (1..=batch_size) - .map(|id| path.join(format!("output_{id}.png")).into()) - .collect() + #[test] + fn test_valid_txt2img_config() { + let config = Txt2ImgConfigBuilder::default() + .prompt("testing prompt") + .build(); + assert!(config.is_ok(), "Txt2ImgConfig should succeed with prompt"); } -} -unsafe fn upscale( - upscale_repeats: i32, - upscaler_ctx: Option<*mut upscaler_ctx_t>, - data: sd_image_t, -) -> Result { - match upscaler_ctx { - Some(upscaler_ctx) => { - let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth - let mut current_image = data; - for _ in 0..upscale_repeats { - let upscaled_image = - diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor); - - if upscaled_image.data.is_null() { - return Err(DiffusionError::Upscaler); - } - - free(current_image.data as *mut c_void); - current_image = upscaled_image; - } - Ok(current_image) - } - None => Ok(data), + #[test] + fn test_model_ctx_new_invalid() { + let config = ModelConfigBuilder::default().build(); + assert!(config.is_err()); + // Attempt creating ModelCtx with error + // This is hypothetical; we expect a builder error before this } -} -/// Generate an image with a prompt -pub fn txt2img(mut config: Config) -> Result<(), DiffusionError> { - unsafe { - let prompt: CLibString = match &config.prompt_suffix { - Some(suffix) => format!("{} {suffix}", &config.prompt), - None => config.prompt.clone(), - } - .into(); - let sd_ctx = config.build_sd_ctx(true); - let upscaler_ctx = config.upscaler_ctx(); - let res = { - let slice = diffusion_rs_sys::txt2img( - sd_ctx, - prompt.as_ptr(), - config.negative_prompt.as_ptr(), - config.clip_skip as i32, - config.cfg_scale, - config.guidance, - config.width, - config.height, - config.sampling_method, - config.steps, - config.seed, - config.batch_count, - null(), - config.control_strength, - config.style_ratio, - config.normalize_input, - config.input_id_images.as_ptr(), - config.skip_layer.as_mut_ptr(), - config.skip_layer.len(), - config.slg_scale, - config.skip_layer_start, - config.skip_layer_end, - ); - if slice.is_null() { - return Err(DiffusionError::Forward); - } - let files = output_files(config.output, config.batch_count); - for (id, (img, path)) in slice::from_raw_parts(slice, config.batch_count as usize) - .iter() - .zip(files) - .enumerate() - { - match upscale(config.upscale_repeats, upscaler_ctx, *img) { - Ok(img) => { - let status = stbi_write_png_custom( - path.as_ptr(), - img.width as i32, - img.height as i32, - img.channel as i32, - img.data as *const c_void, - 0, - ); - if status == 0 { - return Err(DiffusionError::StoreImages(id, config.batch_count)); - } + #[test] + fn test_txt2img_success() { + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) + .lora_model_dir(PathBuf::from("./models/loras")) + .taesd(PathBuf::from("./models/taesd1.safetensors")) + .flash_attention(true) + .schedule(Schedule::AYS) + .build() + .expect("Failed to build model config"); + let mut ctx = ModelCtx::new(config.clone()); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, portrait, purple skin") + .add_lora_model("pcm_sd15_lcmlike_lora_converted".to_owned(), 1.0) + .sample_steps(2) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(256) + .width(256) + .clip_skip(ClipSkip::OneLayer) + .build() + .expect("Failed to build txt2img config"); + let txt2img_conf2 = Txt2ImgConfigBuilder::default() + .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, portrait, golden skin") + .add_lora_model("pcm_sd15_lcmlike_lora_converted".to_owned(), 1.0) + .sample_steps(2) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(256) + .width(256) + .clip_skip(ClipSkip::OneLayer) + .build() + .expect("Failed to build txt2img config"); + let result = ctx.txt2img(txt2img_conf); + let result2 = ctx.txt2img(txt2img_conf2); + match result { + Ok(images) => { + //save image for testing + images.iter().enumerate().for_each(|(i, img)| { + img.save(format!("./test_image_{}.png", i)).unwrap(); + }); + match result2 { + Ok(images) => { + //save image for testing + images.iter().enumerate().for_each(|(i, img)| { + img.save(format!("./test_image2_{}.png", i)).unwrap(); + }); } - Err(err) => { - return Err(err); + Err(e) => { + panic!("Error: {:?}", e); } } } - - //Clean-up slice section - free(slice as *mut c_void); - Ok(()) + Err(e) => { + panic!("Error: {:?}", e); + } }; - - //Clean-up CTX section - free_sd_ctx(sd_ctx); - if let Some(upscaler_ctx) = upscaler_ctx { - free_upscaler_ctx(upscaler_ctx); - } - res } -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use crate::{api::ConfigBuilderError, util::download_file_hf_hub}; - - use super::{txt2img, ConfigBuilder}; #[test] - fn test_required_args_txt2img() { - assert!(ConfigBuilder::default().build().is_err()); - assert!(ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .build() - .is_err()); - - assert!(ConfigBuilder::default() - .prompt("a lovely cat driving a sport car") - .build() - .is_err()); - - assert!(matches!( - ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .prompt("a lovely cat driving a sport car") - .batch_count(10) - .build(), - Err(ConfigBuilderError::ValidationError(_)) - )); - - ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .prompt("a lovely cat driving a sport car") + fn test_txt2img_failure() { + // Build a context with invalid data to force failure + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) .build() .unwrap(); - - ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .prompt("a lovely duck drinking water from a bottle") - .batch_count(2) - .output(PathBuf::from("./")) + let mut ctx = ModelCtx::new(config); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("test prompt") + .sample_steps(1) .build() .unwrap(); + // Hypothetical failure scenario + let result = ctx.txt2img(txt2img_conf); + // Expect an error if calling with invalid path + // This depends on your real implementation + assert!(result.is_err() || result.is_ok()); } - #[ignore] #[test] - fn test_txt2img() { - let model_path = - download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt") - .unwrap(); - - let upscaler_path = download_file_hf_hub( - "ximso/RealESRGAN_x4plus_anime_6B", - "RealESRGAN_x4plus_anime_6B.pth", - ) - .unwrap(); - let config = ConfigBuilder::default() - .model(model_path) - .prompt("a lovely duck drinking water from a bottle") - .output(PathBuf::from("./output_1.png")) - .upscale_model(upscaler_path) - .upscale_repeats(1) - .batch_count(1) + fn test_multiple_images() { + let config = ModelConfigBuilder::default() + .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) + .build() + .unwrap(); + let mut ctx = ModelCtx::new(config); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("multi-image prompt") + .sample_steps(1) + .batch_count(3) .build() .unwrap(); - txt2img(config).unwrap(); + let result = ctx.txt2img(txt2img_conf); + assert!(result.is_ok()); + if let Ok(images) = result { + assert_eq!(images.len(), 3); + } } } diff --git a/src/api_modified.rs b/src/api_modified.rs deleted file mode 100644 index ed03555..0000000 --- a/src/api_modified.rs +++ /dev/null @@ -1,570 +0,0 @@ -use std::ffi::c_char; -use std::ffi::c_void; -use std::ffi::CString; -use std::path::Path; -use std::path::PathBuf; -use std::ptr::null; -use std::slice; - -use derive_builder::Builder; -use diffusion_rs_sys::sd_image_t; -use image::ImageBuffer; -use image::Rgb; -use image::RgbImage; -use libc::free; -use thiserror::Error; - -use diffusion_rs_sys::free_sd_ctx; -use diffusion_rs_sys::new_sd_ctx; -use diffusion_rs_sys::sd_ctx_t; - -/// Specify the range function -pub use diffusion_rs_sys::rng_type_t as RngFunction; - -/// Sampling methods -pub use diffusion_rs_sys::sample_method_t as SampleMethod; - -/// Denoiser sigma schedule -pub use diffusion_rs_sys::schedule_t as Schedule; - -/// Weight type -pub use diffusion_rs_sys::sd_type_t as WeightType; - -#[non_exhaustive] -#[derive(Error, Debug)] -/// Error that can occurs while forwarding models -pub enum DiffusionError { - #[error("The underling stablediffusion.cpp function returned NULL")] - Forward, - #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] - StoreImages(usize, i32), - #[error("The underling upscaler model returned a NULL image")] - Upscaler, - #[error("Failed to convert image buffer to rust type")] - SDImagetoRustImage, - // #[error("Free Params Immediately is set to true, which means that the params are freed after forward. This means that the model can only be used once")] - // FreeParamsImmediately, -} - -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] -/// Ignore the lower X layers of CLIP network -pub enum ClipSkip { - /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x - #[default] - Unspecified = 0, - None = 1, - OneLayer = 2, -} - -#[derive(Debug, Clone, Default)] -struct CLibString(CString); - -impl CLibString { - fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() - } -} - -impl From<&str> for CLibString { - fn from(value: &str) -> Self { - Self(CString::new(value).unwrap()) - } -} - -impl From for CLibString { - fn from(value: String) -> Self { - Self(CString::new(value).unwrap()) - } -} - -#[derive(Debug, Clone, Default)] -struct CLibPath(CString); - -impl CLibPath { - fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() - } -} - -impl From for CLibPath { - fn from(value: PathBuf) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) - } -} - -impl From<&Path> for CLibPath { - fn from(value: &Path) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) - } -} - -#[derive(Builder, Debug, Clone)] -#[builder(setter(into), build_fn(validate = "Self::validate"))] -/// Config struct common to all diffusion methods -pub struct ModelConfig { - /// Path to full model - #[builder(default = "Default::default()")] - model: CLibPath, - - /// path to the clip-l text encoder - #[builder(default = "Default::default()")] - clip_l: CLibPath, - - /// path to the clip-g text encoder - #[builder(default = "Default::default()")] - clip_g: CLibPath, - - /// Path to the t5xxl text encoder - #[builder(default = "Default::default()")] - t5xxl: CLibPath, - - /// Path to the standalone diffusion model - #[builder(default = "Default::default()")] - diffusion_model: CLibPath, - - /// Path to vae - #[builder(default = "Default::default()")] - vae: CLibPath, - - /// Path to taesd. Using Tiny AutoEncoder for fast decoding (lower quality) - #[builder(default = "Default::default()")] - taesd: CLibPath, - - /// Path to control net model - #[builder(default = "Default::default()")] - control_net: CLibPath, - - /// Lora models directory - #[builder(default = "Default::default()", setter(custom))] - lora_model_dir: CLibPath, - - /// Path to embeddings directory - #[builder(default = "Default::default()")] - embeddings_dir: CLibPath, - - /// Path to PHOTOMAKER stacked id embeddings - #[builder(default = "Default::default()")] - stacked_id_embd_dir: CLibPath, - - //TODO: Add more info here for docs - /// vae decode only (default: false) - #[builder(default = "false")] - vae_decode_only: bool, - - /// Process vae in tiles to reduce memory usage (default: false) - #[builder(default = "false")] - vae_tiling: bool, - - /// free memory of params immediately after forward (default: false) - #[builder(default = "false")] - free_params_immediately: bool, - - /// Number of threads to use during computation (default: 0). - /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. - #[builder( - default = "std::thread::available_parallelism().map_or(1, |p| p.get() as i32)", - setter(custom) - )] - n_threads: i32, - - /// Weight type. If not specified, the default is the type of the weight file - #[builder(default = "WeightType::SD_TYPE_COUNT")] - weight_type: WeightType, - - /// RNG type (default: CUDA) - #[builder(default = "RngFunction::CUDA_RNG")] - rng_type: RngFunction, - - /// Denoiser sigma schedule (default: DEFAULT) - #[builder(default = "Schedule::DEFAULT")] - schedule: Schedule, - - /// keep clip on cpu (for low vram) (default: false) - #[builder(default = "false")] - keep_clip_on_cpu: bool, - - /// Keep controlnet in cpu (for low vram) (default: false) - #[builder(default = "false")] - keep_control_net_cpu: bool, - - /// Keep vae on cpu (for low vram) (default: false) - #[builder(default = "false")] - keep_vae_on_cpu: bool, - - /// Use flash attention in the diffusion model (for low vram). - /// Might lower quality, since it implies converting k and v to f16. - /// This might crash if it is not supported by the backend. - /// must have feature "flash_attention" enabled in the features. - /// (default: false) - #[builder(default = "false")] - flash_attention: bool, -} - -impl ModelConfigBuilder { - pub fn n_threads(&mut self, value: i32) -> &mut Self { - self.n_threads = if value > 0 { - Some(value) - } else { - Some(std::thread::available_parallelism().map_or(1, |p| p.get() as i32)) - }; - self - } - - fn validate(&self) -> Result<(), ModelConfigBuilderError> { - self.validate_model() - } - - fn validate_model(&self) -> Result<(), ModelConfigBuilderError> { - self.model - .as_ref() - .or(self.diffusion_model.as_ref()) - .map(|_| ()) - .ok_or(ModelConfigBuilderError::UninitializedField( - "Model OR DiffusionModel must be initialized", - )) - } -} - -#[derive(Builder, Debug, Clone)] -#[builder(setter(into), build_fn(validate = "Self::validate"))] -/// txt2img config -struct Txt2ImgConfig { - /// Prompt to generate image from - prompt: String, - - /// Suffix that needs to be added to prompt (e.g. lora model) - #[builder(default = "Default::default()", private)] - lora_prompt_suffix: Vec, - - /// The negative prompt (default: "") - #[builder(default = "\"\".into()")] - negative_prompt: CLibString, - - /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) - /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x - #[builder(default = "ClipSkip::Unspecified")] - clip_skip: ClipSkip, - - /// Unconditional guidance scale (default: 7.0) - #[builder(default = "7.0")] - cfg_scale: f32, - - /// Guidance (default: 3.5) - #[builder(default = "3.5")] - guidance: f32, - - /// Image height, in pixel space (default: 512) - #[builder(default = "512")] - height: i32, - - /// Image width, in pixel space (default: 512) - #[builder(default = "512")] - width: i32, - - /// Sampling-method (default: EULER_A) - #[builder(default = "SampleMethod::EULER_A")] - sample_method: SampleMethod, - - /// Number of sample steps (default: 20) - #[builder(default = "20")] - sample_steps: i32, - - /// RNG seed (default: 42, use random seed for < 0) - #[builder(default = "42")] - seed: i64, - - /// Number of images to generate (default: 1) - #[builder(default = "1")] - batch_count: i32, - - /// Strength to apply Control Net (default: 0.9) - /// 1.0 corresponds to full destruction of information in init - #[builder(default = "0.9")] - control_strength: f32, - - /// Strength for keeping input identity (default: 20%) - #[builder(default = "20.0")] - style_ratio: f32, - - /// Normalize PHOTOMAKER input id images - #[builder(default = "false")] - normalize_input: bool, - - /// Path to PHOTOMAKER input id images dir - #[builder(default = "Default::default()")] - input_id_images: CLibPath, - - /// Layers to skip for SLG steps: (default: [7,8,9]) - #[builder(default = "vec![7, 8, 9]")] - skip_layer: Vec, - - /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) - /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium - #[builder(default = "0.")] - slg_scale: f32, - - /// SLG enabling point: (default: 0.01) - #[builder(default = "0.01")] - skip_layer_start: f32, - - /// SLG disabling point: (default: 0.2) - #[builder(default = "0.2")] - skip_layer_end: f32, -} - -impl Txt2ImgConfigBuilder { - fn validate(&self) -> Result<(), Txt2ImgConfigBuilderError> { - self.validate_prompt() - } - - fn validate_prompt(&self) -> Result<(), Txt2ImgConfigBuilderError> { - self.prompt - .as_ref() - .map(|_| ()) - .ok_or(Txt2ImgConfigBuilderError::UninitializedField("Prompt")) - } - - pub fn add_lora_model(&mut self, filename: String, strength: f32) -> &mut Self { - self.lora_prompt_suffix - .get_or_insert_with(Vec::new) - .push(format!("")); - self - } -} - -struct ModelCtx { - /// The underlying C context - raw_ctx: *mut sd_ctx_t, - - /// We keep the config around in case we need to refer to it - pub model_config: ModelConfig, -} - -impl ModelCtx { - pub fn new(config: ModelConfig) -> Self { - let raw_ctx = unsafe { - new_sd_ctx( - config.model.as_ptr(), - config.clip_l.as_ptr(), - config.clip_g.as_ptr(), - config.t5xxl.as_ptr(), - config.diffusion_model.as_ptr(), - config.vae.as_ptr(), - config.taesd.as_ptr(), - config.control_net.as_ptr(), - config.lora_model_dir.as_ptr(), - config.embeddings_dir.as_ptr(), - config.stacked_id_embd_dir.as_ptr(), - config.vae_decode_only, - config.vae_tiling, - config.free_params_immediately, - config.n_threads, - config.weight_type, - config.rng_type, - config.schedule, - config.keep_clip_on_cpu, - config.keep_control_net_cpu, - config.keep_vae_on_cpu, - config.flash_attention, - ) - }; - - Self { - raw_ctx, - model_config: config, - } - } - - pub fn destroy(&mut self) { - unsafe { - if !self.raw_ctx.is_null() { - free_sd_ctx(self.raw_ctx); - self.raw_ctx = std::ptr::null_mut(); - } - } - } - - pub fn txt2img( - &mut self, - mut txt2img_config: Txt2ImgConfig, - ) -> Result, DiffusionError> { - // add loras to prompt as suffix - let prompt: CLibString = { - let mut prompt = txt2img_config.prompt.clone(); - for lora in txt2img_config.lora_prompt_suffix.iter() { - prompt.push_str(lora); - } - prompt.into() - }; - - let results: *mut sd_image_t = unsafe { - diffusion_rs_sys::txt2img( - self.raw_ctx, - prompt.as_ptr(), - txt2img_config.negative_prompt.as_ptr(), - txt2img_config.clip_skip as i32, - txt2img_config.cfg_scale, - txt2img_config.guidance, - txt2img_config.width, - txt2img_config.height, - txt2img_config.sample_method, - txt2img_config.sample_steps, - txt2img_config.seed, - txt2img_config.batch_count, - null(), - txt2img_config.control_strength, - txt2img_config.style_ratio, - txt2img_config.normalize_input, - txt2img_config.input_id_images.as_ptr(), - txt2img_config.skip_layer.as_mut_ptr(), - txt2img_config.skip_layer.len(), - txt2img_config.slg_scale, - txt2img_config.skip_layer_start, - txt2img_config.skip_layer_end, - ) - }; - - if results.is_null() { - return Err(DiffusionError::Forward); - } - - let result_images: Vec = unsafe { - let img_count = txt2img_config.batch_count as usize; - let images = slice::from_raw_parts(results, img_count); - let rgb_images: Result, DiffusionError> = images - .iter() - .map(|sd_img| { - let len = (sd_img.width * sd_img.height * sd_img.channel) as usize; - let raw_pixels = slice::from_raw_parts(sd_img.data, len); - let buffer = raw_pixels.to_vec(); - let buffer = ImageBuffer::, _>::from_raw( - sd_img.width as u32, - sd_img.height as u32, - buffer, - ); - Ok(match buffer { - Some(buffer) => RgbImage::from(buffer), - None => return Err(DiffusionError::SDImagetoRustImage), - }) - }) - .collect(); - match rgb_images { - Ok(images) => images, - Err(e) => return Err(e), - } - }; - - //Clean-up slice section - unsafe { - free(results as *mut c_void); - } - Ok(result_images) - } -} - -/// Automatic cleanup on drop -impl Drop for ModelCtx { - fn drop(&mut self) { - self.destroy(); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::path::PathBuf; - - #[test] - fn test_invalid_model_config() { - let config = ModelConfigBuilder::default().build(); - assert!(config.is_err(), "ModelConfig should fail without a model"); - } - - #[test] - fn test_valid_model_config() { - let config = ModelConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .build(); - assert!(config.is_ok(), "ModelConfig should succeed with model path"); - } - - #[test] - fn test_invalid_txt2img_config() { - let config = Txt2ImgConfigBuilder::default().build(); - assert!(config.is_err(), "Txt2ImgConfig should fail without prompt"); - } - - #[test] - fn test_valid_txt2img_config() { - let config = Txt2ImgConfigBuilder::default() - .prompt("testing prompt") - .build(); - assert!(config.is_ok(), "Txt2ImgConfig should succeed with prompt"); - } - - #[test] - fn test_model_ctx_new_invalid() { - let config = ModelConfigBuilder::default().build(); - assert!(config.is_err()); - // Attempt creating ModelCtx with error - // This is hypothetical; we expect a builder error before this - } - - #[test] - fn test_txt2img_success() { - let config = ModelConfigBuilder::default() - .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) - .build() - .unwrap(); - let mut ctx = ModelCtx::new(config.clone()); - let txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("test prompt") - .sample_steps(1) - .build() - .unwrap(); - let result = ctx.txt2img(txt2img_conf); - assert!(result.is_ok()); - } - - #[test] - fn test_txt2img_failure() { - // Build a context with invalid data to force failure - let config = ModelConfigBuilder::default() - .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) - .build() - .unwrap(); - let mut ctx = ModelCtx::new(config); - let txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("test prompt") - .sample_steps(1) - .build() - .unwrap(); - // Hypothetical failure scenario - let result = ctx.txt2img(txt2img_conf); - // Expect an error if calling with invalid path - // This depends on your real implementation - assert!(result.is_err() || result.is_ok()); - } - - #[test] - fn test_multiple_images() { - let config = ModelConfigBuilder::default() - .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) - .build() - .unwrap(); - let mut ctx = ModelCtx::new(config); - let txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("multi-image prompt") - .sample_steps(1) - .batch_count(3) - .build() - .unwrap(); - let result = ctx.txt2img(txt2img_conf); - assert!(result.is_ok()); - if let Ok(images) = result { - assert_eq!(images.len(), 3); - } - } -} diff --git a/src/img2img_config.rs b/src/img2img_config.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/lib.rs b/src/lib.rs index 4d194a4..6a4c58e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,6 @@ #![doc = include_str!("../README.md")] -/// Modified API -pub mod api_modified; +pub mod api; +pub mod model_config; +pub mod txt2img_config; +pub mod utils; diff --git a/src/model_config.rs b/src/model_config.rs new file mode 100644 index 0000000..f95a077 --- /dev/null +++ b/src/model_config.rs @@ -0,0 +1,139 @@ +use derive_builder::Builder; + +use crate::utils::CLibPath; + +/// Specify the range function +pub use diffusion_rs_sys::rng_type_t as RngFunction; + +/// Denoiser sigma schedule +pub use diffusion_rs_sys::schedule_t as Schedule; + +/// Weight type +pub use diffusion_rs_sys::sd_type_t as WeightType; + +#[derive(Builder, Debug, Clone)] +#[builder(setter(into), build_fn(validate = "Self::validate"))] +/// Config struct common to all diffusion methods +pub struct ModelConfig { + /// Path to full model + #[builder(default = "Default::default()")] + pub model: CLibPath, + + /// path to the clip-l text encoder + #[builder(default = "Default::default()")] + pub clip_l: CLibPath, + + /// path to the clip-g text encoder + #[builder(default = "Default::default()")] + pub clip_g: CLibPath, + + /// Path to the t5xxl text encoder + #[builder(default = "Default::default()")] + pub t5xxl: CLibPath, + + /// Path to the standalone diffusion model + #[builder(default = "Default::default()")] + pub diffusion_model: CLibPath, + + /// Path to vae + #[builder(default = "Default::default()")] + pub vae: CLibPath, + + /// Path to taesd. Using Tiny AutoEncoder for fast decoding (lower quality) + #[builder(default = "Default::default()")] + pub taesd: CLibPath, + + /// Path to control net model + #[builder(default = "Default::default()")] + pub control_net: CLibPath, + + /// Lora models directory + #[builder(default = "Default::default()")] + pub lora_model_dir: CLibPath, + + /// Path to embeddings directory + #[builder(default = "Default::default()")] + pub embeddings_dir: CLibPath, + + /// Path to PHOTOMAKER stacked id embeddings + #[builder(default = "Default::default()")] + pub stacked_id_embd_dir: CLibPath, + + //TODO: Add more info here for docs + /// vae decode only (default: false) + #[builder(default = "false")] + pub vae_decode_only: bool, + + /// Process vae in tiles to reduce memory usage (default: false) + #[builder(default = "false")] + pub vae_tiling: bool, + + /// free memory of params immediately after forward (default: false) + #[builder(default = "false")] + pub free_params_immediately: bool, + + /// Number of threads to use during computation (default: 0). + /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. + #[builder( + default = "std::thread::available_parallelism().map_or(1, |p| p.get() as i32)", + setter(custom) + )] + pub n_threads: i32, + + /// Weight type. If not specified, the default is the type of the weight file + #[builder(default = "WeightType::SD_TYPE_COUNT")] + pub weight_type: WeightType, + + /// RNG type (default: CUDA) + #[builder(default = "RngFunction::CUDA_RNG")] + pub rng_type: RngFunction, + + /// Denoiser sigma schedule (default: DEFAULT) + #[builder(default = "Schedule::DEFAULT")] + pub schedule: Schedule, + + /// keep clip on cpu (for low vram) (default: false) + #[builder(default = "false")] + pub keep_clip_on_cpu: bool, + + /// Keep controlnet in cpu (for low vram) (default: false) + #[builder(default = "false")] + pub keep_control_net_cpu: bool, + + /// Keep vae on cpu (for low vram) (default: false) + #[builder(default = "false")] + pub keep_vae_on_cpu: bool, + + /// Use flash attention in the diffusion model (for low vram). + /// Might lower quality, since it implies converting k and v to f16. + /// This might crash if it is not supported by the backend. + /// must have feature "flash_attention" enabled in the features. + /// (default: false) + #[builder(default = "false")] + pub flash_attention: bool, +} + +impl ModelConfigBuilder { + pub fn n_threads(&mut self, value: i32) -> &mut Self { + self.n_threads = if value > 0 { + Some(value) + } else { + Some(std::thread::available_parallelism().map_or(1, |p| p.get() as i32)) + }; + self + } + + fn validate(&self) -> Result<(), ModelConfigBuilderError> { + self.validate_model() + } + + fn validate_model(&self) -> Result<(), ModelConfigBuilderError> { + self.model + .as_ref() + .or(self.diffusion_model.as_ref()) + .map(|_| ()) + .ok_or(ModelConfigBuilderError::UninitializedField( + "Model OR DiffusionModel must be initialized", + )) + } +} diff --git a/src/old_api.rs b/src/old_api.rs new file mode 100644 index 0000000..d6f4fbe --- /dev/null +++ b/src/old_api.rs @@ -0,0 +1,574 @@ +use std::ffi::c_char; +use std::ffi::c_void; +use std::ffi::CString; +use std::path::Path; +use std::path::PathBuf; +use std::ptr::null; +use std::slice; + +use derive_builder::Builder; +use diffusion_rs_sys::free_upscaler_ctx; +use diffusion_rs_sys::new_upscaler_ctx; +use diffusion_rs_sys::sd_image_t; +use diffusion_rs_sys::upscaler_ctx_t; +use libc::free; +use thiserror::Error; + +use diffusion_rs_sys::free_sd_ctx; +use diffusion_rs_sys::new_sd_ctx; +use diffusion_rs_sys::sd_ctx_t; +use diffusion_rs_sys::stbi_write_png_custom; + +/// Specify the range function +pub use diffusion_rs_sys::rng_type_t as RngFunction; + +/// Sampling methods +pub use diffusion_rs_sys::sample_method_t as SampleMethod; + +/// Denoiser sigma schedule +pub use diffusion_rs_sys::schedule_t as Schedule; + +/// Weight type +pub use diffusion_rs_sys::sd_type_t as WeightType; + +#[non_exhaustive] +#[derive(Error, Debug)] +/// Error that can occurs while forwarding models +pub enum DiffusionError { + #[error("The underling stablediffusion.cpp function returned NULL")] + Forward, + #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] + StoreImages(usize, i32), + #[error("The underling upsclaer model returned a NULL image")] + Upscaler, +} + +#[repr(i32)] +#[non_exhaustive] +#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] +/// Ignore the lower X layers of CLIP network +pub enum ClipSkip { + /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x + #[default] + Unspecified = 0, + None = 1, + OneLayer = 2, +} + +#[derive(Builder, Debug, Clone)] +#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] + +/// Config struct common to all diffusion methods +pub struct Config { + /// Number of threads to use during computation (default: 0). + /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. + #[builder(default = "num_cpus::get_physical() as i32", setter(custom))] + n_threads: i32, + + /// Path to full model + #[builder(default = "Default::default()")] + model: CLibPath, + + /// Path to the standalone diffusion model + #[builder(default = "Default::default()")] + diffusion_model: CLibPath, + + /// path to the clip-l text encoder + #[builder(default = "Default::default()")] + clip_l: CLibPath, + + /// path to the clip-g text encoder + #[builder(default = "Default::default()")] + clip_g: CLibPath, + + /// Path to the t5xxl text encoder + #[builder(default = "Default::default()")] + t5xxl: CLibPath, + + /// Path to vae + #[builder(default = "Default::default()")] + vae: CLibPath, + + /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + #[builder(default = "Default::default()")] + taesd: CLibPath, + + /// Path to control net model + #[builder(default = "Default::default()")] + control_net: CLibPath, + + /// Path to embeddings + #[builder(default = "Default::default()")] + embeddings: CLibPath, + + /// Path to PHOTOMAKER stacked id embeddings + #[builder(default = "Default::default()")] + stacked_id_embd: CLibPath, + + /// Path to PHOTOMAKER input id images dir + #[builder(default = "Default::default()")] + input_id_images: CLibPath, + + /// Normalize PHOTOMAKER input id images + #[builder(default = "false")] + normalize_input: bool, + + /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now + #[builder(default = "Default::default()")] + upscale_model: Option, + + /// Run the ESRGAN upscaler this many times (default 1) + #[builder(default = "0")] + upscale_repeats: i32, + + /// Weight type. If not specified, the default is the type of the weight file + #[builder(default = "WeightType::SD_TYPE_COUNT")] + weight_type: WeightType, + + /// Lora model directory + #[builder(default = "Default::default()", setter(custom))] + lora_model: CLibPath, + + /// Path to the input image, required by img2img + #[builder(default = "Default::default()")] + init_img: CLibPath, + + /// Path to image condition, control net + #[builder(default = "Default::default()")] + control_image: CLibPath, + + /// Path to write result image to (default: ./output.png) + #[builder(default = "PathBuf::from(\"./output.png\")")] + output: PathBuf, + + /// The prompt to render + prompt: String, + + /// The negative prompt (default: "") + #[builder(default = "\"\".into()")] + negative_prompt: CLibString, + + /// Unconditional guidance scale (default: 7.0) + #[builder(default = "7.0")] + cfg_scale: f32, + + /// Guidance (default: 3.5) + #[builder(default = "3.5")] + guidance: f32, + + /// Strength for noising/unnoising (default: 0.75) + #[builder(default = "0.75")] + strength: f32, + + /// Strength for keeping input identity (default: 20%) + #[builder(default = "20.0")] + style_ratio: f32, + + /// Strength to apply Control Net (default: 0.9) + /// 1.0 corresponds to full destruction of information in init + #[builder(default = "0.9")] + control_strength: f32, + + /// Image height, in pixel space (default: 512) + #[builder(default = "512")] + height: i32, + + /// Image width, in pixel space (default: 512) + #[builder(default = "512")] + width: i32, + + /// Sampling-method (default: EULER_A) + #[builder(default = "SampleMethod::EULER_A")] + sampling_method: SampleMethod, + + /// Number of sample steps (default: 20) + #[builder(default = "20")] + steps: i32, + + /// RNG (default: CUDA) + #[builder(default = "RngFunction::CUDA_RNG")] + rng: RngFunction, + + /// RNG seed (default: 42, use random seed for < 0) + #[builder(default = "42")] + seed: i64, + + /// Number of images to generate (default: 1) + #[builder(default = "1")] + batch_count: i32, + + /// Denoiser sigma schedule (default: DEFAULT) + #[builder(default = "Schedule::DEFAULT")] + schedule: Schedule, + + /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) + /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x + #[builder(default = "ClipSkip::Unspecified")] + clip_skip: ClipSkip, + + /// Process vae in tiles to reduce memory usage (default: false) + #[builder(default = "false")] + vae_tiling: bool, + + /// free memory of params immediately after forward (default: true) + #[builder(default = "true")] + free_params_immediately: bool, + + /// Keep vae in cpu (for low vram) (default: false) + #[builder(default = "false")] + vae_on_cpu: bool, + + /// keep clip in cpu (for low vram) (default: false) + #[builder(default = "false")] + clip_on_cpu: bool, + + /// Keep controlnet in cpu (for low vram) (default: false) + #[builder(default = "false")] + control_net_cpu: bool, + + /// Apply canny preprocessor (edge detection) (default: false) + #[builder(default = "false")] + canny: bool, + + /// Suffix that needs to be added to prompt (e.g. lora model) + #[builder(default = "None", private)] + prompt_suffix: Option, + + /// Use flash attention in the diffusion model (for low vram). + /// Might lower quality, since it implies converting k and v to f16. + /// This might crash if it is not supported by the backend. + #[builder(default = "false")] + flash_attention: bool, + + /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) + /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium + #[builder(default = "0.")] + slg_scale: f32, + + /// Layers to skip for SLG steps: (default: [7,8,9]) + #[builder(default = "vec![7, 8, 9]")] + skip_layer: Vec, + + /// SLG enabling point: (default: 0.01) + #[builder(default = "0.01")] + skip_layer_start: f32, + + /// SLG disabling point: (default: 0.2) + #[builder(default = "0.2")] + skip_layer_end: f32, +} + +impl ConfigBuilder { + /// add Lora model and clip strength to the prompt suffix + /// e.g. "" + pub fn lora_model(&mut self, lora_model: &Path, clip_strength: f32) -> &mut Self { + let folder = lora_model.parent().unwrap(); + let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned(); + self.prompt_suffix(format!("")); + self.lora_model = Some(folder.into()); + self + } + + pub fn n_threads(&mut self, value: i32) -> &mut Self { + self.n_threads = if value > 0 { + Some(value) + } else { + Some(num_cpus::get_physical() as i32) + }; + self + } + + fn validate(&self) -> Result<(), ConfigBuilderError> { + self.validate_model()?; + self.validate_output_dir() + } + + fn validate_model(&self) -> Result<(), ConfigBuilderError> { + self.model + .as_ref() + .or(self.diffusion_model.as_ref()) + .map(|_| ()) + .ok_or(ConfigBuilderError::UninitializedField( + "Model OR DiffusionModel must be valorized", + )) + } + + fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> { + let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir()); + let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1); + if is_dir == multiple_items { + Ok(()) + } else { + Err(ConfigBuilderError::ValidationError( + "When batch_count > 0, ouput should point to folder and viceversa".to_owned(), + )) + } + } +} + +impl Config { + unsafe fn build_sd_ctx(&self, vae_decode_only: bool) -> *mut sd_ctx_t { + new_sd_ctx( + self.model.as_ptr(), + self.clip_l.as_ptr(), + self.clip_g.as_ptr(), + self.t5xxl.as_ptr(), + self.diffusion_model.as_ptr(), + self.vae.as_ptr(), + self.taesd.as_ptr(), + self.control_net.as_ptr(), + self.lora_model.as_ptr(), + self.embeddings.as_ptr(), + self.stacked_id_embd.as_ptr(), + vae_decode_only, + self.vae_tiling, + self.free_params_immediately, + self.n_threads, + self.weight_type, + self.rng, + self.schedule, + self.clip_on_cpu, + self.control_net_cpu, + self.vae_on_cpu, + self.flash_attention, + ) + } + + unsafe fn upscaler_ctx(&self) -> Option<*mut upscaler_ctx_t> { + if self.upscale_model.is_none() || self.upscale_repeats == 0 { + None + } else { + let upscaler = new_upscaler_ctx( + self.upscale_model.as_ref().unwrap().as_ptr(), + self.n_threads, + ); + Some(upscaler) + } + } +} + +#[derive(Debug, Clone, Default)] +struct CLibString(CString); + +impl CLibString { + fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } +} + +impl From<&str> for CLibString { + fn from(value: &str) -> Self { + Self(CString::new(value).unwrap()) + } +} + +impl From for CLibString { + fn from(value: String) -> Self { + Self(CString::new(value).unwrap()) + } +} + +#[derive(Debug, Clone, Default)] +struct CLibPath(CString); + +impl CLibPath { + fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } +} + +impl From for CLibPath { + fn from(value: PathBuf) -> Self { + Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + } +} + +impl From<&Path> for CLibPath { + fn from(value: &Path) -> Self { + Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + } +} + +fn output_files(path: PathBuf, batch_size: i32) -> Vec { + if batch_size == 1 { + vec![path.into()] + } else { + (1..=batch_size) + .map(|id| path.join(format!("output_{id}.png")).into()) + .collect() + } +} + +unsafe fn upscale( + upscale_repeats: i32, + upscaler_ctx: Option<*mut upscaler_ctx_t>, + data: sd_image_t, +) -> Result { + match upscaler_ctx { + Some(upscaler_ctx) => { + let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth + let mut current_image = data; + for _ in 0..upscale_repeats { + let upscaled_image = + diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor); + + if upscaled_image.data.is_null() { + return Err(DiffusionError::Upscaler); + } + + free(current_image.data as *mut c_void); + current_image = upscaled_image; + } + Ok(current_image) + } + None => Ok(data), + } +} + +/// Generate an image with a prompt +pub fn txt2img(mut config: Config) -> Result<(), DiffusionError> { + unsafe { + let prompt: CLibString = match &config.prompt_suffix { + Some(suffix) => format!("{} {suffix}", &config.prompt), + None => config.prompt.clone(), + } + .into(); + let sd_ctx = config.build_sd_ctx(true); + let upscaler_ctx = config.upscaler_ctx(); + let res = { + let slice = diffusion_rs_sys::txt2img( + sd_ctx, + prompt.as_ptr(), + config.negative_prompt.as_ptr(), + config.clip_skip as i32, + config.cfg_scale, + config.guidance, + config.width, + config.height, + config.sampling_method, + config.steps, + config.seed, + config.batch_count, + null(), + config.control_strength, + config.style_ratio, + config.normalize_input, + config.input_id_images.as_ptr(), + config.skip_layer.as_mut_ptr(), + config.skip_layer.len(), + config.slg_scale, + config.skip_layer_start, + config.skip_layer_end, + ); + if slice.is_null() { + return Err(DiffusionError::Forward); + } + let files = output_files(config.output, config.batch_count); + for (id, (img, path)) in slice::from_raw_parts(slice, config.batch_count as usize) + .iter() + .zip(files) + .enumerate() + { + match upscale(config.upscale_repeats, upscaler_ctx, *img) { + Ok(img) => { + let status = stbi_write_png_custom( + path.as_ptr(), + img.width as i32, + img.height as i32, + img.channel as i32, + img.data as *const c_void, + 0, + ); + if status == 0 { + return Err(DiffusionError::StoreImages(id, config.batch_count)); + } + } + Err(err) => { + return Err(err); + } + } + } + + //Clean-up slice section + free(slice as *mut c_void); + Ok(()) + }; + + //Clean-up CTX section + free_sd_ctx(sd_ctx); + if let Some(upscaler_ctx) = upscaler_ctx { + free_upscaler_ctx(upscaler_ctx); + } + res + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use crate::{api::ConfigBuilderError, util::download_file_hf_hub}; + + use super::{txt2img, ConfigBuilder}; + + #[test] + fn test_required_args_txt2img() { + assert!(ConfigBuilder::default().build().is_err()); + assert!(ConfigBuilder::default() + .model(PathBuf::from("./test.ckpt")) + .build() + .is_err()); + + assert!(ConfigBuilder::default() + .prompt("a lovely cat driving a sport car") + .build() + .is_err()); + + assert!(matches!( + ConfigBuilder::default() + .model(PathBuf::from("./test.ckpt")) + .prompt("a lovely cat driving a sport car") + .batch_count(10) + .build(), + Err(ConfigBuilderError::ValidationError(_)) + )); + + ConfigBuilder::default() + .model(PathBuf::from("./test.ckpt")) + .prompt("a lovely cat driving a sport car") + .build() + .unwrap(); + + ConfigBuilder::default() + .model(PathBuf::from("./test.ckpt")) + .prompt("a lovely duck drinking water from a bottle") + .batch_count(2) + .output(PathBuf::from("./")) + .build() + .unwrap(); + } + + #[ignore] + #[test] + fn test_txt2img() { + let model_path = + download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt") + .unwrap(); + + let upscaler_path = download_file_hf_hub( + "ximso/RealESRGAN_x4plus_anime_6B", + "RealESRGAN_x4plus_anime_6B.pth", + ) + .unwrap(); + let config = ConfigBuilder::default() + .model(model_path) + .prompt("a lovely duck drinking water from a bottle") + .output(PathBuf::from("./output_1.png")) + .upscale_model(upscaler_path) + .upscale_repeats(1) + .batch_count(1) + .build() + .unwrap(); + txt2img(config).unwrap(); + } +} diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs new file mode 100644 index 0000000..bbef5a8 --- /dev/null +++ b/src/txt2img_config.rs @@ -0,0 +1,112 @@ +use crate::utils::{CLibPath, CLibString, ClipSkip}; +use derive_builder::Builder; + +/// Sampling methods +pub use diffusion_rs_sys::sample_method_t as SampleMethod; + +#[derive(Builder, Debug, Clone)] +#[builder(setter(into), build_fn(validate = "Self::validate"))] +/// txt2img config +pub struct Txt2ImgConfig { + /// Prompt to generate image from + pub prompt: String, + + /// Suffix that needs to be added to prompt (e.g. lora model) + #[builder(default = "Default::default()", private)] + pub lora_prompt_suffix: Vec, + + /// The negative prompt (default: "") + #[builder(default = "\"\".into()")] + pub negative_prompt: CLibString, + + /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) + /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x + #[builder(default = "ClipSkip::Unspecified")] + pub clip_skip: ClipSkip, + + /// Unconditional guidance scale (default: 7.0) + #[builder(default = "7.0")] + pub cfg_scale: f32, + + /// Guidance (default: 3.5) + #[builder(default = "3.5")] + pub guidance: f32, + + /// Image height, in pixel space (default: 512) + #[builder(default = "512")] + pub height: i32, + + /// Image width, in pixel space (default: 512) + #[builder(default = "512")] + pub width: i32, + + /// Sampling-method (default: EULER_A) + #[builder(default = "SampleMethod::EULER_A")] + pub sample_method: SampleMethod, + + /// Number of sample steps (default: 20) + #[builder(default = "20")] + pub sample_steps: i32, + + /// RNG seed (default: 42, use random seed for < 0) + #[builder(default = "42")] + pub seed: i64, + + /// Number of images to generate (default: 1) + #[builder(default = "1")] + pub batch_count: i32, + + /// Strength to apply Control Net (default: 0.9) + /// 1.0 corresponds to full destruction of information in init + #[builder(default = "0.9")] + pub control_strength: f32, + + /// Strength for keeping input identity (default: 20%) + #[builder(default = "20.0")] + pub style_ratio: f32, + + /// Normalize PHOTOMAKER input id images + #[builder(default = "false")] + pub normalize_input: bool, + + /// Path to PHOTOMAKER input id images dir + #[builder(default = "Default::default()")] + pub input_id_images: CLibPath, + + /// Layers to skip for SLG steps: (default: [7,8,9]) + #[builder(default = "vec![7, 8, 9]")] + pub skip_layer: Vec, + + /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) + /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium + #[builder(default = "0.")] + pub slg_scale: f32, + + /// SLG enabling point: (default: 0.01) + #[builder(default = "0.01")] + pub skip_layer_start: f32, + + /// SLG disabling point: (default: 0.2) + #[builder(default = "0.2")] + pub skip_layer_end: f32, +} + +impl Txt2ImgConfigBuilder { + fn validate(&self) -> Result<(), Txt2ImgConfigBuilderError> { + self.validate_prompt() + } + + fn validate_prompt(&self) -> Result<(), Txt2ImgConfigBuilderError> { + self.prompt + .as_ref() + .map(|_| ()) + .ok_or(Txt2ImgConfigBuilderError::UninitializedField("Prompt")) + } + + pub fn add_lora_model(&mut self, filename: String, strength: f32) -> &mut Self { + self.lora_prompt_suffix + .get_or_insert_with(Vec::new) + .push(format!("")); + self + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..e70f15a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,76 @@ +use std::ffi::c_char; + +use std::ffi::CString; +use std::path::Path; +use std::path::PathBuf; +use thiserror::Error; + +#[non_exhaustive] +#[derive(Error, Debug)] +/// Error that can occurs while forwarding models +pub enum DiffusionError { + #[error("The underling stablediffusion.cpp function returned NULL")] + Forward, + #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] + StoreImages(usize, i32), + #[error("The underling upscaler model returned a NULL image")] + Upscaler, + #[error("Failed to convert image buffer to rust type")] + SDImagetoRustImage, + // #[error("Free Params Immediately is set to true, which means that the params are freed after forward. This means that the model can only be used once")] + // FreeParamsImmediately, +} + +#[repr(i32)] +#[non_exhaustive] +#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] +/// Ignore the lower X layers of CLIP network +pub enum ClipSkip { + /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x + #[default] + Unspecified = 0, + None = 1, + OneLayer = 2, +} + +#[derive(Debug, Clone, Default)] +pub struct CLibString(pub CString); + +impl CLibString { + pub fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } +} + +impl From<&str> for CLibString { + fn from(value: &str) -> Self { + Self(CString::new(value).unwrap()) + } +} + +impl From for CLibString { + fn from(value: String) -> Self { + Self(CString::new(value).unwrap()) + } +} + +#[derive(Debug, Clone, Default)] +pub struct CLibPath(CString); + +impl CLibPath { + pub fn as_ptr(&self) -> *const c_char { + self.0.as_ptr() + } +} + +impl From for CLibPath { + fn from(value: PathBuf) -> Self { + Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + } +} + +impl From<&Path> for CLibPath { + fn from(value: &Path) -> Self { + Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) + } +} From ac0212fca9dfe450aaa4382a733d62f1cb11e891 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 3 Feb 2025 18:32:07 -0500 Subject: [PATCH 03/33] add support for image crate conversions --- src/api.rs | 65 +++++++++++------------ src/img2img_config.rs | 0 src/model_config.rs | 11 +--- src/txt2img_config.rs | 11 ++-- src/utils.rs | 117 ++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 151 insertions(+), 53 deletions(-) delete mode 100644 src/img2img_config.rs diff --git a/src/api.rs b/src/api.rs index c7b134b..7c54c59 100644 --- a/src/api.rs +++ b/src/api.rs @@ -3,8 +3,6 @@ use std::ptr::null; use std::slice; use diffusion_rs_sys::sd_image_t; -use image::ImageBuffer; -use image::Rgb; use image::RgbImage; use libc::free; @@ -16,9 +14,10 @@ use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; use crate::utils::CLibString; use crate::utils::DiffusionError; +use crate::utils::SdImageContainer; struct ModelCtx { /// The underlying C context - raw_ctx: *mut sd_ctx_t, + raw_ctx: Option<*mut sd_ctx_t>, /// We keep the config around in case we need to refer to it pub model_config: ModelConfig, @@ -27,7 +26,7 @@ struct ModelCtx { impl ModelCtx { pub fn new(config: ModelConfig) -> Self { let raw_ctx = unsafe { - new_sd_ctx( + let ptr = new_sd_ctx( config.model.as_ptr(), config.clip_l.as_ptr(), config.clip_g.as_ptr(), @@ -50,7 +49,12 @@ impl ModelCtx { config.keep_control_net_cpu, config.keep_vae_on_cpu, config.flash_attention, - ) + ); + if ptr.is_null() { + None + } else { + Some(ptr) + } }; Self { @@ -60,10 +64,9 @@ impl ModelCtx { } pub fn destroy(&mut self) { - unsafe { - if !self.raw_ctx.is_null() { - free_sd_ctx(self.raw_ctx); - self.raw_ctx = std::ptr::null_mut(); + if let Some(ptr) = self.raw_ctx.take() { + unsafe { + free_sd_ctx(ptr); } } } @@ -87,11 +90,20 @@ impl ModelCtx { prompt.0.to_str().expect("Couldn't get string") ); - //controlnet support + //controlnet + let control_image: *const sd_image_t = match txt2img_config.control_cond { + Some(image) => { + let wrapper = SdImageContainer::try_from(image)?; + wrapper.as_ptr() + } + None => null(), + }; + + //run text to image let results: *mut sd_image_t = unsafe { diffusion_rs_sys::txt2img( - self.raw_ctx, + self.raw_ctx.ok_or(DiffusionError::NoContext)?, prompt.as_ptr(), txt2img_config.negative_prompt.as_ptr(), txt2img_config.clip_skip as i32, @@ -103,9 +115,9 @@ impl ModelCtx { txt2img_config.sample_steps, txt2img_config.seed, txt2img_config.batch_count, - null(), + control_image, txt2img_config.control_strength, - txt2img_config.style_ratio, + txt2img_config.style_strength, txt2img_config.normalize_input, txt2img_config.input_id_images.as_ptr(), txt2img_config.skip_layer.as_mut_ptr(), @@ -120,30 +132,13 @@ impl ModelCtx { return Err(DiffusionError::Forward); } - let result_images: Vec = unsafe { + let result_images: Vec = { let img_count = txt2img_config.batch_count as usize; - let images = slice::from_raw_parts(results, img_count); - let rgb_images: Result, DiffusionError> = images + let images = unsafe { slice::from_raw_parts(results, img_count) }; + images .iter() - .map(|sd_img| { - let len = (sd_img.width * sd_img.height * sd_img.channel) as usize; - let raw_pixels = slice::from_raw_parts(sd_img.data, len); - let buffer = raw_pixels.to_vec(); - let buffer = ImageBuffer::, _>::from_raw( - sd_img.width as u32, - sd_img.height as u32, - buffer, - ); - Ok(match buffer { - Some(buffer) => RgbImage::from(buffer), - None => return Err(DiffusionError::SDImagetoRustImage), - }) - }) - .collect(); - match rgb_images { - Ok(images) => images, - Err(e) => return Err(e), - } + .filter_map(|sd_img| RgbImage::try_from(SdImageContainer::from(*sd_img)).ok()) + .collect() }; //Clean-up slice section diff --git a/src/img2img_config.rs b/src/img2img_config.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/model_config.rs b/src/model_config.rs index f95a077..b410095 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,15 +1,6 @@ use derive_builder::Builder; -use crate::utils::CLibPath; - -/// Specify the range function -pub use diffusion_rs_sys::rng_type_t as RngFunction; - -/// Denoiser sigma schedule -pub use diffusion_rs_sys::schedule_t as Schedule; - -/// Weight type -pub use diffusion_rs_sys::sd_type_t as WeightType; +use crate::utils::{CLibPath, RngFunction, Schedule, WeightType}; #[derive(Builder, Debug, Clone)] #[builder(setter(into), build_fn(validate = "Self::validate"))] diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index bbef5a8..bd35f47 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -1,8 +1,6 @@ -use crate::utils::{CLibPath, CLibString, ClipSkip}; +use crate::utils::{CLibPath, CLibString, ClipSkip, SampleMethod}; use derive_builder::Builder; - -/// Sampling methods -pub use diffusion_rs_sys::sample_method_t as SampleMethod; +use image::RgbImage; #[derive(Builder, Debug, Clone)] #[builder(setter(into), build_fn(validate = "Self::validate"))] @@ -56,6 +54,9 @@ pub struct Txt2ImgConfig { #[builder(default = "1")] pub batch_count: i32, + #[builder(setter(into, strip_option))] + pub control_cond: Option, + /// Strength to apply Control Net (default: 0.9) /// 1.0 corresponds to full destruction of information in init #[builder(default = "0.9")] @@ -63,7 +64,7 @@ pub struct Txt2ImgConfig { /// Strength for keeping input identity (default: 20%) #[builder(default = "20.0")] - pub style_ratio: f32, + pub style_strength: f32, /// Normalize PHOTOMAKER input id images #[builder(default = "false")] diff --git a/src/utils.rs b/src/utils.rs index e70f15a..9c91c93 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,8 +1,11 @@ +use image::ImageBuffer; +use image::Rgb; +use image::RgbImage; use std::ffi::c_char; - use std::ffi::CString; use std::path::Path; use std::path::PathBuf; +use std::slice; use thiserror::Error; #[non_exhaustive] @@ -15,12 +18,23 @@ pub enum DiffusionError { StoreImages(usize, i32), #[error("The underling upscaler model returned a NULL image")] Upscaler, - #[error("Failed to convert image buffer to rust type")] - SDImagetoRustImage, + #[error("sd_ctx_t is None")] + NoContext, + #[error("SD image conversion error: {0}")] + SDImageError(#[from] SDImageError), // #[error("Free Params Immediately is set to true, which means that the params are freed after forward. This means that the model can only be used once")] // FreeParamsImmediately, } +#[non_exhaustive] +#[derive(Debug, thiserror::Error)] +pub enum SDImageError { + #[error("Failed to convert image buffer to Rust type")] + AllocationError, + #[error("The image buffer has a different length than expected")] + DifferentLength, +} + #[repr(i32)] #[non_exhaustive] #[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] @@ -74,3 +88,100 @@ impl From<&Path> for CLibPath { Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) } } + +/// Specify the range function +pub use diffusion_rs_sys::rng_type_t as RngFunction; + +/// Denoiser sigma schedule +pub use diffusion_rs_sys::schedule_t as Schedule; + +/// Weight type +pub use diffusion_rs_sys::sd_type_t as WeightType; + +/// Sampling methods +pub use diffusion_rs_sys::sample_method_t as SampleMethod; + +/// Image buffer Type +pub use diffusion_rs_sys::sd_image_t; + +#[derive(Debug, Clone)] +pub struct SdImageContainer { + // Wrap the raw external type. + inner: sd_image_t, +} + +impl SdImageContainer { + pub fn as_ptr(&self) -> *const sd_image_t { + &self.inner + } +} + +impl From for SdImageContainer { + fn from(inner: sd_image_t) -> Self { + Self { inner } + } +} + +impl TryFrom for SdImageContainer { + type Error = SDImageError; + + fn try_from(img: RgbImage) -> Result { + let (width, height) = img.dimensions(); + // For an RGB image, we have 3 channels. + let channel = 3u32; + let expected_len = (width * height * channel) as usize; + + // Convert the image into its raw pixel data (a Vec). + let pixel_data: Vec = img.into_raw(); + + // Ensure that the pixel data is of the expected length. + if pixel_data.len() != expected_len { + return Err(SDImageError::DifferentLength); + } + + let data_ptr = unsafe { + let ptr = libc::malloc(expected_len) as *mut u8; + if ptr.is_null() { + return Err(SDImageError::AllocationError); + } + std::ptr::copy_nonoverlapping(pixel_data.as_ptr(), ptr, expected_len); + ptr + }; + + Ok(SdImageContainer { + inner: sd_image_t { + width, + height, + channel, + data: data_ptr, + }, + }) + } +} + +impl Drop for SdImageContainer { + fn drop(&mut self) { + unsafe { + libc::free(self.inner.data as *mut libc::c_void); + } + } +} + +impl TryFrom for RgbImage { + type Error = SDImageError; + + fn try_from(sd_image: SdImageContainer) -> Result { + let len = (sd_image.inner.width * sd_image.inner.height * sd_image.inner.channel) as usize; + let raw_pixels = unsafe { slice::from_raw_parts(sd_image.inner.data, len) }; + let buffer = raw_pixels.to_vec(); + let buffer = ImageBuffer::, _>::from_raw( + sd_image.inner.width as u32, + sd_image.inner.height as u32, + buffer, + ); + Ok(match buffer { + Some(buffer) => RgbImage::from(buffer), + None => return Err(SDImageError::AllocationError), + }) + } +} From 81e7365f6be654d18ef439e8c111089b10d46028 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 3 Feb 2025 20:02:08 -0500 Subject: [PATCH 04/33] fix control net issues, (control net does not work because of sdcpp) --- src/api.rs | 58 ++++++++++++++++++++++++++++--------------- src/txt2img_config.rs | 2 +- src/utils.rs | 4 ++- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/api.rs b/src/api.rs index 7c54c59..c108158 100644 --- a/src/api.rs +++ b/src/api.rs @@ -24,7 +24,7 @@ struct ModelCtx { } impl ModelCtx { - pub fn new(config: ModelConfig) -> Self { + pub fn new(config: ModelConfig) -> Result { let raw_ctx = unsafe { let ptr = new_sd_ctx( config.model.as_ptr(), @@ -51,16 +51,16 @@ impl ModelCtx { config.flash_attention, ); if ptr.is_null() { - None + return Err(DiffusionError::NewContextFailure); } else { Some(ptr) } }; - Self { + Ok(Self { raw_ctx, model_config: config, - } + }) } pub fn destroy(&mut self) { @@ -92,12 +92,16 @@ impl ModelCtx { //controlnet - let control_image: *const sd_image_t = match txt2img_config.control_cond { - Some(image) => { - let wrapper = SdImageContainer::try_from(image)?; - wrapper.as_ptr() + let control_image = if self.model_config.control_net.as_ptr().is_null() { + match txt2img_config.control_cond { + Some(image) => { + let wrapper = SdImageContainer::try_from(image)?; + wrapper.as_ptr() + } + None => null(), } - None => null(), + } else { + null() }; //run text to image @@ -162,6 +166,7 @@ mod tests { pub use diffusion_rs_sys::sample_method_t as SampleMethod; /// Denoiser sigma schedule pub use diffusion_rs_sys::schedule_t as Schedule; + use image::ImageReader; use crate::utils::ClipSkip; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; @@ -207,15 +212,27 @@ mod tests { #[test] fn test_txt2img_success() { - let config = ModelConfigBuilder::default() - .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) - .lora_model_dir(PathBuf::from("./models/loras")) - .taesd(PathBuf::from("./models/taesd1.safetensors")) - .flash_attention(true) - .schedule(Schedule::AYS) - .build() - .expect("Failed to build model config"); - let mut ctx = ModelCtx::new(config.clone()); + let control_image = ImageReader::open("openposetest.png") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .into_rgb8(); + + let mut ctx = ModelCtx::new( + ModelConfigBuilder::default() + .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) + .lora_model_dir(PathBuf::from("./models/loras")) + .taesd(PathBuf::from("./models/taesd1.safetensors")) + .control_net(PathBuf::from( + "./models/controlnet/sd15openpose11.safetensors", + )) + .flash_attention(true) + .schedule(Schedule::AYS) + .build() + .expect("Failed to build model config"), + ) + .expect("Failed to build model context"); + let txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, portrait, purple skin") .add_lora_model("pcm_sd15_lcmlike_lora_converted".to_owned(), 1.0) @@ -229,6 +246,7 @@ mod tests { .expect("Failed to build txt2img config"); let txt2img_conf2 = Txt2ImgConfigBuilder::default() .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, portrait, golden skin") + .control_cond(control_image) .add_lora_model("pcm_sd15_lcmlike_lora_converted".to_owned(), 1.0) .sample_steps(2) .sample_method(SampleMethod::LCM) @@ -271,7 +289,7 @@ mod tests { .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) .build() .unwrap(); - let mut ctx = ModelCtx::new(config); + let mut ctx = ModelCtx::new(config).expect("Failed to build model context"); let txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("test prompt") .sample_steps(1) @@ -290,7 +308,7 @@ mod tests { .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) .build() .unwrap(); - let mut ctx = ModelCtx::new(config); + let mut ctx = ModelCtx::new(config).expect("Failed to build model context"); let txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("multi-image prompt") .sample_steps(1) diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index bd35f47..63bbea6 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -54,7 +54,7 @@ pub struct Txt2ImgConfig { #[builder(default = "1")] pub batch_count: i32, - #[builder(setter(into, strip_option))] + #[builder(setter(strip_option), default)] pub control_cond: Option, /// Strength to apply Control Net (default: 0.9) diff --git a/src/utils.rs b/src/utils.rs index 9c91c93..cce62e4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -18,8 +18,10 @@ pub enum DiffusionError { StoreImages(usize, i32), #[error("The underling upscaler model returned a NULL image")] Upscaler, - #[error("sd_ctx_t is None")] + #[error("raw_ctx is None")] NoContext, + #[error("new_sd_ctx returned null")] + NewContextFailure, #[error("SD image conversion error: {0}")] SDImageError(#[from] SDImageError), // #[error("Free Params Immediately is set to true, which means that the params are freed after forward. This means that the model can only be used once")] From c0988c3e8d006250e8141e5c34e1b2465bd3d449 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 3 Feb 2025 20:09:24 -0500 Subject: [PATCH 05/33] fix dead code warnings --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index 6a4c58e..fdec561 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![doc = include_str!("../README.md")] +#![allow(dead_code)] pub mod api; pub mod model_config; From 5765afaa7c7a536aeb6d25eb2373246fa56d9441 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 3 Feb 2025 20:28:34 -0500 Subject: [PATCH 06/33] public model ctx --- src/api.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/api.rs b/src/api.rs index c108158..01fa2fc 100644 --- a/src/api.rs +++ b/src/api.rs @@ -15,7 +15,7 @@ use crate::txt2img_config::Txt2ImgConfig; use crate::utils::CLibString; use crate::utils::DiffusionError; use crate::utils::SdImageContainer; -struct ModelCtx { +pub struct ModelCtx { /// The underlying C context raw_ctx: Option<*mut sd_ctx_t>, From 8f7c25239ce33a7f5d68f0846d8ac4ba90b2e239 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Tue, 4 Feb 2025 11:18:17 -0500 Subject: [PATCH 07/33] control net cleanup --- src/api.rs | 171 ++++++++++++++++++++++-------------------- src/txt2img_config.rs | 2 +- src/utils.rs | 4 +- 3 files changed, 95 insertions(+), 82 deletions(-) diff --git a/src/api.rs b/src/api.rs index 01fa2fc..fd77457 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,20 +1,19 @@ -use std::ffi::c_void; -use std::ptr::null; -use std::slice; - -use diffusion_rs_sys::sd_image_t; -use image::RgbImage; -use libc::free; - -use diffusion_rs_sys::free_sd_ctx; -use diffusion_rs_sys::new_sd_ctx; -use diffusion_rs_sys::sd_ctx_t; - use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; use crate::utils::CLibString; use crate::utils::DiffusionError; use crate::utils::SdImageContainer; +use diffusion_rs_sys::free_sd_ctx; +use diffusion_rs_sys::new_sd_ctx; +use diffusion_rs_sys::sd_ctx_t; +use diffusion_rs_sys::sd_image_t; +use diffusion_rs_sys::strlen; +use image::RgbImage; +use libc::free; +use std::ffi::c_void; +use std::ptr::null; +use std::slice; + pub struct ModelCtx { /// The underlying C context raw_ctx: Option<*mut sd_ctx_t>, @@ -63,14 +62,6 @@ impl ModelCtx { }) } - pub fn destroy(&mut self) { - if let Some(ptr) = self.raw_ctx.take() { - unsafe { - free_sd_ctx(ptr); - } - } - } - pub fn txt2img( &mut self, mut txt2img_config: Txt2ImgConfig, @@ -92,16 +83,20 @@ impl ModelCtx { //controlnet - let control_image = if self.model_config.control_net.as_ptr().is_null() { - match txt2img_config.control_cond { - Some(image) => { - let wrapper = SdImageContainer::try_from(image)?; - wrapper.as_ptr() + let control_image = match txt2img_config.control_cond { + Some(image) => { + match unsafe { strlen(self.model_config.control_net.as_ptr()) as usize > 0 } { + true => { + let wrapper = SdImageContainer::try_from(image)?; + wrapper.as_ptr() + } + false => { + println!("Control net model is null, setting control image to null"); + null() + } } - None => null(), } - } else { - null() + None => null(), }; //run text to image @@ -156,20 +151,20 @@ impl ModelCtx { /// Automatic cleanup on drop impl Drop for ModelCtx { fn drop(&mut self) { - self.destroy(); + match self.raw_ctx { + Some(ptr) => unsafe { + free_sd_ctx(ptr); + }, + None => {} + } } } #[cfg(test)] mod tests { - /// Sampling methods - pub use diffusion_rs_sys::sample_method_t as SampleMethod; - /// Denoiser sigma schedule - pub use diffusion_rs_sys::schedule_t as Schedule; - use image::ImageReader; - - use crate::utils::ClipSkip; + use crate::utils::{ClipSkip, SampleMethod, Schedule, WeightType}; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; + use image::ImageReader; use super::*; use std::path::PathBuf; @@ -212,7 +207,7 @@ mod tests { #[test] fn test_txt2img_success() { - let control_image = ImageReader::open("openposetest.png") + let control_image = ImageReader::open("canny-384x.jpg") .expect("Failed to open image") .decode() .expect("Failed to decode image") @@ -224,8 +219,9 @@ mod tests { .lora_model_dir(PathBuf::from("./models/loras")) .taesd(PathBuf::from("./models/taesd1.safetensors")) .control_net(PathBuf::from( - "./models/controlnet/sd15openpose11.safetensors", + "./models/controlnet/control_canny-fp16.safetensors", )) + //.weight_type(WeightType::SD_TYPE_Q4_1) .flash_attention(true) .schedule(Schedule::AYS) .build() @@ -233,53 +229,68 @@ mod tests { ) .expect("Failed to build model context"); - let txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, portrait, purple skin") - .add_lora_model("pcm_sd15_lcmlike_lora_converted".to_owned(), 1.0) - .sample_steps(2) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(256) - .width(256) - .clip_skip(ClipSkip::OneLayer) - .build() - .expect("Failed to build txt2img config"); - let txt2img_conf2 = Txt2ImgConfigBuilder::default() - .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, portrait, golden skin") + let result = ctx + .txt2img(Txt2ImgConfigBuilder::default() + .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk") .control_cond(control_image) - .add_lora_model("pcm_sd15_lcmlike_lora_converted".to_owned(), 1.0) - .sample_steps(2) + .control_strength(0.4) + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .sample_steps(6) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) - .height(256) - .width(256) + .height(384) + .width(384) .clip_skip(ClipSkip::OneLayer) + .batch_count(1) .build() - .expect("Failed to build txt2img config"); - let result = ctx.txt2img(txt2img_conf); - let result2 = ctx.txt2img(txt2img_conf2); - match result { - Ok(images) => { - //save image for testing - images.iter().enumerate().for_each(|(i, img)| { - img.save(format!("./test_image_{}.png", i)).unwrap(); - }); - match result2 { - Ok(images) => { - //save image for testing - images.iter().enumerate().for_each(|(i, img)| { - img.save(format!("./test_image2_{}.png", i)).unwrap(); - }); - } - Err(e) => { - panic!("Error: {:?}", e); - } - } - } - Err(e) => { - panic!("Error: {:?}", e); - } - }; + .expect("Failed to build txt2img config 1")) + .expect("Failed to generate image 1"); + + result.iter().enumerate().for_each(|(i, img)| { + img.save(format!("./test_image_{}.png", i)).unwrap(); + }); + + // let result2 = ctx + // .txt2img(Txt2ImgConfigBuilder::default() + // .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, white skin, white dress, blue eyes") + // .control_cond(control_image1) + // .control_strength(0.4) + // .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + // .sample_steps(6) + // .sample_method(SampleMethod::LCM) + // .cfg_scale(1.0) + // .height(384) + // .width(384) + // .clip_skip(ClipSkip::OneLayer) + // .batch_count(1) + // .build() + // .expect("Failed to build txt2img config 2")) + // .expect("Failed to generate image 2"); + + // result2.iter().enumerate().for_each(|(i, img)| { + // img.save(format!("./test_image2_{}.png", i)).unwrap(); + // }); + + // let result3 = ctx + // .txt2img(Txt2ImgConfigBuilder::default() + // .prompt("masterpiece, best quality, absurdres, 1girl, short hair, brown hair, dark skin, dark green suit, brown eyes, clenched_teeth") + // .control_cond(control_image2) + // .control_strength(0.4) + // .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + // .sample_steps(6) + // .sample_method(SampleMethod::LCM) + // .cfg_scale(1.0) + // .height(384) + // .width(384) + // .clip_skip(ClipSkip::OneLayer) + // .batch_count(1) + // .build() + // .expect("Failed to build txt2img config 2")) + // .expect("Failed to generate image 2"); + + // result3.iter().enumerate().for_each(|(i, img)| { + // img.save(format!("./test_image3_{}.png", i)).unwrap(); + // }); } #[test] diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index 63bbea6..36afff4 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -104,7 +104,7 @@ impl Txt2ImgConfigBuilder { .ok_or(Txt2ImgConfigBuilderError::UninitializedField("Prompt")) } - pub fn add_lora_model(&mut self, filename: String, strength: f32) -> &mut Self { + pub fn add_lora_model(&mut self, filename: &str, strength: f32) -> &mut Self { self.lora_prompt_suffix .get_or_insert_with(Vec::new) .push(format!("")); diff --git a/src/utils.rs b/src/utils.rs index cce62e4..6810847 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,7 +1,9 @@ use image::ImageBuffer; use image::Rgb; use image::RgbImage; +use libc::free; use std::ffi::c_char; +use std::ffi::c_void; use std::ffi::CString; use std::path::Path; use std::path::PathBuf; @@ -164,7 +166,7 @@ impl TryFrom for SdImageContainer { impl Drop for SdImageContainer { fn drop(&mut self) { unsafe { - libc::free(self.inner.data as *mut libc::c_void); + free(self.inner.data as *mut c_void); } } } From 62d6b1b8468575450cb612bfb5d917acc81f66db Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Tue, 4 Feb 2025 11:28:47 -0500 Subject: [PATCH 08/33] add images for testing in gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ad33e30..eaba1a9 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ bin/act *.png .idea/ -models/ \ No newline at end of file +models/ +images/ \ No newline at end of file From f3babbea1e63766ab4538eb5712bedc91acf22e3 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Tue, 4 Feb 2025 16:44:22 -0500 Subject: [PATCH 09/33] test expansion and remove hf hub dependency --- Cargo.toml | 1 - src/api.rs | 118 ++++++++++++++++++++++++++++++----------------------- 2 files changed, 68 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c4df08c..e74009b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ documentation = "https://docs.rs/diffusion-rs" [dependencies] derive_builder = "0.20.2" diffusion-rs-sys = { path = "sys", version = "0.1.6" } -hf-hub = { version = "0.4.0", default-features = false, features = ["ureq"] } image = "0.25.5" libc = "0.2.161" num_cpus = "1.16.0" diff --git a/src/api.rs b/src/api.rs index fd77457..8a7bd4a 100644 --- a/src/api.rs +++ b/src/api.rs @@ -162,11 +162,10 @@ impl Drop for ModelCtx { #[cfg(test)] mod tests { + use super::*; use crate::utils::{ClipSkip, SampleMethod, Schedule, WeightType}; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; - - use super::*; use std::path::PathBuf; #[test] @@ -207,12 +206,28 @@ mod tests { #[test] fn test_txt2img_success() { - let control_image = ImageReader::open("canny-384x.jpg") + let control_image1 = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .into_rgb8(); + + let control_image2 = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") .decode() .expect("Failed to decode image") .into_rgb8(); + let control_image3 = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .into_rgb8(); + + let resolution = 384; + let sample_steps = 4; + let control_strength = 0.4; + let mut ctx = ModelCtx::new( ModelConfigBuilder::default() .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) @@ -232,14 +247,14 @@ mod tests { let result = ctx .txt2img(Txt2ImgConfigBuilder::default() .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk") - .control_cond(control_image) - .control_strength(0.4) + .control_cond(control_image1) + .control_strength(control_strength) .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .sample_steps(6) + .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) - .height(384) - .width(384) + .height(resolution) + .width(resolution) .clip_skip(ClipSkip::OneLayer) .batch_count(1) .build() @@ -247,50 +262,53 @@ mod tests { .expect("Failed to generate image 1"); result.iter().enumerate().for_each(|(i, img)| { - img.save(format!("./test_image_{}.png", i)).unwrap(); + img.save(format!("./images/test_1_{}x_{}.png", resolution, i)) + .unwrap(); }); - // let result2 = ctx - // .txt2img(Txt2ImgConfigBuilder::default() - // .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, white skin, white dress, blue eyes") - // .control_cond(control_image1) - // .control_strength(0.4) - // .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - // .sample_steps(6) - // .sample_method(SampleMethod::LCM) - // .cfg_scale(1.0) - // .height(384) - // .width(384) - // .clip_skip(ClipSkip::OneLayer) - // .batch_count(1) - // .build() - // .expect("Failed to build txt2img config 2")) - // .expect("Failed to generate image 2"); - - // result2.iter().enumerate().for_each(|(i, img)| { - // img.save(format!("./test_image2_{}.png", i)).unwrap(); - // }); - - // let result3 = ctx - // .txt2img(Txt2ImgConfigBuilder::default() - // .prompt("masterpiece, best quality, absurdres, 1girl, short hair, brown hair, dark skin, dark green suit, brown eyes, clenched_teeth") - // .control_cond(control_image2) - // .control_strength(0.4) - // .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - // .sample_steps(6) - // .sample_method(SampleMethod::LCM) - // .cfg_scale(1.0) - // .height(384) - // .width(384) - // .clip_skip(ClipSkip::OneLayer) - // .batch_count(1) - // .build() - // .expect("Failed to build txt2img config 2")) - // .expect("Failed to generate image 2"); - - // result3.iter().enumerate().for_each(|(i, img)| { - // img.save(format!("./test_image3_{}.png", i)).unwrap(); - // }); + let result2 = ctx + .txt2img(Txt2ImgConfigBuilder::default() + .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy") + .control_cond(control_image2) + .control_strength(control_strength) + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(ClipSkip::OneLayer) + .batch_count(1) + .build() + .expect("Failed to build txt2img config 1")) + .expect("Failed to generate image 1"); + + result2.iter().enumerate().for_each(|(i, img)| { + img.save(format!("./images/test_2_{}x_{}.png", resolution, i)) + .unwrap(); + }); + + let result3 = ctx + .txt2img(Txt2ImgConfigBuilder::default() + .prompt("masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy") + .control_cond(control_image3) + .control_strength(control_strength) + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(ClipSkip::OneLayer) + .batch_count(1) + .build() + .expect("Failed to build txt2img config 1")) + .expect("Failed to generate image 1"); + + result3.iter().enumerate().for_each(|(i, img)| { + img.save(format!("./images/test_3_{}x_{}.png", resolution, i)) + .unwrap(); + }); } #[test] From 0aa630106c78fb190f7fced88b69c02ae051036b Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Thu, 6 Feb 2025 17:28:37 -0500 Subject: [PATCH 10/33] simplify controlnet image logic, add original logging --- src/api.rs | 110 ++++++++++++++++++------------------------ src/txt2img_config.rs | 6 +-- src/utils.rs | 102 ++++++++++----------------------------- 3 files changed, 76 insertions(+), 142 deletions(-) diff --git a/src/api.rs b/src/api.rs index 8a7bd4a..ab1067c 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,15 +1,9 @@ use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; -use crate::utils::CLibString; -use crate::utils::DiffusionError; -use crate::utils::SdImageContainer; -use diffusion_rs_sys::free_sd_ctx; -use diffusion_rs_sys::new_sd_ctx; -use diffusion_rs_sys::sd_ctx_t; -use diffusion_rs_sys::sd_image_t; -use diffusion_rs_sys::strlen; +use crate::utils::{convert_image, setup_logging, CLibString, DiffusionError}; +use diffusion_rs_sys::{free_sd_ctx, new_sd_ctx, sd_ctx_t, sd_image_t, txt2img}; use image::RgbImage; -use libc::free; +use libc::{free, strlen}; use std::ffi::c_void; use std::ptr::null; use std::slice; @@ -24,6 +18,8 @@ pub struct ModelCtx { impl ModelCtx { pub fn new(config: ModelConfig) -> Result { + setup_logging(); + let raw_ctx = unsafe { let ptr = new_sd_ctx( config.model.as_ptr(), @@ -75,21 +71,16 @@ impl ModelCtx { prompt.into() }; - //print prompt for debugging - println!( - "Prompt: {:?}", - prompt.0.to_str().expect("Couldn't get string") - ); - //controlnet - - let control_image = match txt2img_config.control_cond { + let control_image: *const sd_image_t = match txt2img_config.control_cond { Some(image) => { match unsafe { strlen(self.model_config.control_net.as_ptr()) as usize > 0 } { - true => { - let wrapper = SdImageContainer::try_from(image)?; - wrapper.as_ptr() - } + true => &sd_image_t { + data: image.as_ptr() as *mut u8, + width: image.width(), + height: image.height(), + channel: 3, + }, false => { println!("Control net model is null, setting control image to null"); null() @@ -101,7 +92,7 @@ impl ModelCtx { //run text to image let results: *mut sd_image_t = unsafe { - diffusion_rs_sys::txt2img( + txt2img( self.raw_ctx.ok_or(DiffusionError::NoContext)?, prompt.as_ptr(), txt2img_config.negative_prompt.as_ptr(), @@ -136,7 +127,7 @@ impl ModelCtx { let images = unsafe { slice::from_raw_parts(results, img_count) }; images .iter() - .filter_map(|sd_img| RgbImage::try_from(SdImageContainer::from(*sd_img)).ok()) + .filter_map(|sd_img| convert_image(sd_img).ok()) .collect() }; @@ -163,7 +154,7 @@ impl Drop for ModelCtx { #[cfg(test)] mod tests { use super::*; - use crate::utils::{ClipSkip, SampleMethod, Schedule, WeightType}; + use crate::utils::{ClipSkip, RngFunction, SampleMethod, Schedule, WeightType}; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; use std::path::PathBuf; @@ -206,25 +197,14 @@ mod tests { #[test] fn test_txt2img_success() { - let control_image1 = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .into_rgb8(); + let resolution: i32 = 384; - let control_image2 = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .into_rgb8(); - - let control_image3 = ImageReader::open("./images/canny-384x.jpg") + let control_image1 = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") .decode() .expect("Failed to decode image") .into_rgb8(); - let resolution = 384; let sample_steps = 4; let control_strength = 0.4; @@ -236,8 +216,10 @@ mod tests { .control_net(PathBuf::from( "./models/controlnet/control_canny-fp16.safetensors", )) - //.weight_type(WeightType::SD_TYPE_Q4_1) + .weight_type(WeightType::SD_TYPE_F16) .flash_attention(true) + //.rng_type(RngFunction::STD_DEFAULT_RNG) + .vae_decode_only(true) .schedule(Schedule::AYS) .build() .expect("Failed to build model config"), @@ -247,12 +229,13 @@ mod tests { let result = ctx .txt2img(Txt2ImgConfigBuilder::default() .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk") - .control_cond(control_image1) + .control_cond(&control_image1) .control_strength(control_strength) .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) + .guidance(0.0) .height(resolution) .width(resolution) .clip_skip(ClipSkip::OneLayer) @@ -269,46 +252,47 @@ mod tests { let result2 = ctx .txt2img(Txt2ImgConfigBuilder::default() .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy") - .control_cond(control_image2) + .control_cond(&control_image1) .control_strength(control_strength) .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) + .guidance(0.0) .height(resolution) .width(resolution) .clip_skip(ClipSkip::OneLayer) .batch_count(1) .build() - .expect("Failed to build txt2img config 1")) - .expect("Failed to generate image 1"); + .expect("Failed to build txt2img config 2")) + .expect("Failed to generate image 2"); result2.iter().enumerate().for_each(|(i, img)| { img.save(format!("./images/test_2_{}x_{}.png", resolution, i)) .unwrap(); }); - let result3 = ctx - .txt2img(Txt2ImgConfigBuilder::default() - .prompt("masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy") - .control_cond(control_image3) - .control_strength(control_strength) - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(ClipSkip::OneLayer) - .batch_count(1) - .build() - .expect("Failed to build txt2img config 1")) - .expect("Failed to generate image 1"); - - result3.iter().enumerate().for_each(|(i, img)| { - img.save(format!("./images/test_3_{}x_{}.png", resolution, i)) - .unwrap(); - }); + // let result3 = ctx + // .txt2img(Txt2ImgConfigBuilder::default() + // .prompt("masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy") + // .control_cond(control_image3) + // .control_strength(control_strength) + // .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + // .sample_steps(sample_steps) + // .sample_method(SampleMethod::LCM) + // .cfg_scale(1.0) + // .height(resolution) + // .width(resolution) + // .clip_skip(ClipSkip::OneLayer) + // .batch_count(1) + // .build() + // .expect("Failed to build txt2img config 1")) + // .expect("Failed to generate image 1"); + + // result3.iter().enumerate().for_each(|(i, img)| { + // img.save(format!("./images/test_3_{}x_{}.png", resolution, i)) + // .unwrap(); + // }); } #[test] diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index 36afff4..857e975 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -5,7 +5,7 @@ use image::RgbImage; #[derive(Builder, Debug, Clone)] #[builder(setter(into), build_fn(validate = "Self::validate"))] /// txt2img config -pub struct Txt2ImgConfig { +pub struct Txt2ImgConfig<'a> { /// Prompt to generate image from pub prompt: String, @@ -55,7 +55,7 @@ pub struct Txt2ImgConfig { pub batch_count: i32, #[builder(setter(strip_option), default)] - pub control_cond: Option, + pub control_cond: Option<&'a RgbImage>, /// Strength to apply Control Net (default: 0.9) /// 1.0 corresponds to full destruction of information in init @@ -92,7 +92,7 @@ pub struct Txt2ImgConfig { pub skip_layer_end: f32, } -impl Txt2ImgConfigBuilder { +impl<'a> Txt2ImgConfigBuilder<'a> { fn validate(&self) -> Result<(), Txt2ImgConfigBuilderError> { self.validate_prompt() } diff --git a/src/utils.rs b/src/utils.rs index 6810847..a1abde9 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,10 +1,15 @@ +use diffusion_rs_sys::sd_get_system_info; +use diffusion_rs_sys::sd_log_level_t; +use diffusion_rs_sys::sd_set_log_callback; use image::ImageBuffer; use image::Rgb; use image::RgbImage; use libc::free; use std::ffi::c_char; use std::ffi::c_void; +use std::ffi::CStr; use std::ffi::CString; + use std::path::Path; use std::path::PathBuf; use std::slice; @@ -105,87 +110,32 @@ pub use diffusion_rs_sys::sd_type_t as WeightType; /// Sampling methods pub use diffusion_rs_sys::sample_method_t as SampleMethod; -/// Image buffer Type -pub use diffusion_rs_sys::sd_image_t; - -#[derive(Debug, Clone)] -pub struct SdImageContainer { - // Wrap the raw external type. - inner: sd_image_t, -} - -impl SdImageContainer { - pub fn as_ptr(&self) -> *const sd_image_t { - &self.inner - } -} - -impl From for SdImageContainer { - fn from(inner: sd_image_t) -> Self { - Self { inner } - } -} - -impl TryFrom for SdImageContainer { - type Error = SDImageError; - - fn try_from(img: RgbImage) -> Result { - let (width, height) = img.dimensions(); - // For an RGB image, we have 3 channels. - let channel = 3u32; - let expected_len = (width * height * channel) as usize; - - // Convert the image into its raw pixel data (a Vec). - let pixel_data: Vec = img.into_raw(); - - // Ensure that the pixel data is of the expected length. - if pixel_data.len() != expected_len { - return Err(SDImageError::DifferentLength); - } - - let data_ptr = unsafe { - let ptr = libc::malloc(expected_len) as *mut u8; - if ptr.is_null() { - return Err(SDImageError::AllocationError); - } - std::ptr::copy_nonoverlapping(pixel_data.as_ptr(), ptr, expected_len); - ptr - }; - - Ok(SdImageContainer { - inner: sd_image_t { - width, - height, - channel, - data: data_ptr, - }, - }) - } +use diffusion_rs_sys::sd_image_t; + +pub fn convert_image(sd_image: &sd_image_t) -> Result { + let len = (sd_image.width * sd_image.height * sd_image.channel) as usize; + let raw_pixels = unsafe { slice::from_raw_parts(sd_image.data, len) }; + let buffer = raw_pixels.to_vec(); + let buffer = + ImageBuffer::, _>::from_raw(sd_image.width as u32, sd_image.height as u32, buffer); + Ok(match buffer { + Some(buffer) => RgbImage::from(buffer), + None => return Err(SDImageError::AllocationError), + }) } -impl Drop for SdImageContainer { - fn drop(&mut self) { - unsafe { - free(self.inner.data as *mut c_void); +extern "C" fn my_log_callback(level: sd_log_level_t, text: *const c_char, _data: *mut c_void) { + unsafe { + // Convert C string to Rust &str and print it. + if !text.is_null() { + let msg = CStr::from_ptr(text).to_str().unwrap_or("Invalid UTF-8"); + print!("({:?}): {}", level, msg); } } } -impl TryFrom for RgbImage { - type Error = SDImageError; - - fn try_from(sd_image: SdImageContainer) -> Result { - let len = (sd_image.inner.width * sd_image.inner.height * sd_image.inner.channel) as usize; - let raw_pixels = unsafe { slice::from_raw_parts(sd_image.inner.data, len) }; - let buffer = raw_pixels.to_vec(); - let buffer = ImageBuffer::, _>::from_raw( - sd_image.inner.width as u32, - sd_image.inner.height as u32, - buffer, - ); - Ok(match buffer { - Some(buffer) => RgbImage::from(buffer), - None => return Err(SDImageError::AllocationError), - }) +pub fn setup_logging() { + unsafe { + sd_set_log_callback(Some(my_log_callback), std::ptr::null_mut()); } } From d10906aedac39d1bd231f1526ae0e5494e72d3f1 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 7 Feb 2025 13:36:36 -0500 Subject: [PATCH 11/33] update mut --- src/api.rs | 13 ++++++------- src/utils.rs | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/api.rs b/src/api.rs index ab1067c..9c31a90 100644 --- a/src/api.rs +++ b/src/api.rs @@ -59,7 +59,7 @@ impl ModelCtx { } pub fn txt2img( - &mut self, + &self, mut txt2img_config: Txt2ImgConfig, ) -> Result, DiffusionError> { // add loras to prompt as suffix @@ -158,6 +158,7 @@ mod tests { use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; use std::path::PathBuf; + use std::sync::Arc; #[test] fn test_invalid_model_config() { @@ -208,7 +209,7 @@ mod tests { let sample_steps = 4; let control_strength = 0.4; - let mut ctx = ModelCtx::new( + let ctx = ModelCtx::new( ModelConfigBuilder::default() .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) .lora_model_dir(PathBuf::from("./models/loras")) @@ -218,7 +219,7 @@ mod tests { )) .weight_type(WeightType::SD_TYPE_F16) .flash_attention(true) - //.rng_type(RngFunction::STD_DEFAULT_RNG) + .rng_type(RngFunction::CUDA_RNG) .vae_decode_only(true) .schedule(Schedule::AYS) .build() @@ -235,7 +236,6 @@ mod tests { .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) - .guidance(0.0) .height(resolution) .width(resolution) .clip_skip(ClipSkip::OneLayer) @@ -258,7 +258,6 @@ mod tests { .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) - .guidance(0.0) .height(resolution) .width(resolution) .clip_skip(ClipSkip::OneLayer) @@ -302,7 +301,7 @@ mod tests { .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) .build() .unwrap(); - let mut ctx = ModelCtx::new(config).expect("Failed to build model context"); + let ctx = ModelCtx::new(config).expect("Failed to build model context"); let txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("test prompt") .sample_steps(1) @@ -321,7 +320,7 @@ mod tests { .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) .build() .unwrap(); - let mut ctx = ModelCtx::new(config).expect("Failed to build model context"); + let ctx = ModelCtx::new(config).expect("Failed to build model context"); let txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("multi-image prompt") .sample_steps(1) diff --git a/src/utils.rs b/src/utils.rs index a1abde9..9461e3f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -46,7 +46,7 @@ pub enum SDImageError { #[repr(i32)] #[non_exhaustive] -#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq)] /// Ignore the lower X layers of CLIP network pub enum ClipSkip { /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x From 749c197b5500c76cac25925941e3facd25d4d23c Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 7 Feb 2025 13:43:28 -0500 Subject: [PATCH 12/33] delete old api --- src/old_api.rs | 574 ------------------------------------------------- 1 file changed, 574 deletions(-) delete mode 100644 src/old_api.rs diff --git a/src/old_api.rs b/src/old_api.rs deleted file mode 100644 index d6f4fbe..0000000 --- a/src/old_api.rs +++ /dev/null @@ -1,574 +0,0 @@ -use std::ffi::c_char; -use std::ffi::c_void; -use std::ffi::CString; -use std::path::Path; -use std::path::PathBuf; -use std::ptr::null; -use std::slice; - -use derive_builder::Builder; -use diffusion_rs_sys::free_upscaler_ctx; -use diffusion_rs_sys::new_upscaler_ctx; -use diffusion_rs_sys::sd_image_t; -use diffusion_rs_sys::upscaler_ctx_t; -use libc::free; -use thiserror::Error; - -use diffusion_rs_sys::free_sd_ctx; -use diffusion_rs_sys::new_sd_ctx; -use diffusion_rs_sys::sd_ctx_t; -use diffusion_rs_sys::stbi_write_png_custom; - -/// Specify the range function -pub use diffusion_rs_sys::rng_type_t as RngFunction; - -/// Sampling methods -pub use diffusion_rs_sys::sample_method_t as SampleMethod; - -/// Denoiser sigma schedule -pub use diffusion_rs_sys::schedule_t as Schedule; - -/// Weight type -pub use diffusion_rs_sys::sd_type_t as WeightType; - -#[non_exhaustive] -#[derive(Error, Debug)] -/// Error that can occurs while forwarding models -pub enum DiffusionError { - #[error("The underling stablediffusion.cpp function returned NULL")] - Forward, - #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] - StoreImages(usize, i32), - #[error("The underling upsclaer model returned a NULL image")] - Upscaler, -} - -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] -/// Ignore the lower X layers of CLIP network -pub enum ClipSkip { - /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x - #[default] - Unspecified = 0, - None = 1, - OneLayer = 2, -} - -#[derive(Builder, Debug, Clone)] -#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] - -/// Config struct common to all diffusion methods -pub struct Config { - /// Number of threads to use during computation (default: 0). - /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. - #[builder(default = "num_cpus::get_physical() as i32", setter(custom))] - n_threads: i32, - - /// Path to full model - #[builder(default = "Default::default()")] - model: CLibPath, - - /// Path to the standalone diffusion model - #[builder(default = "Default::default()")] - diffusion_model: CLibPath, - - /// path to the clip-l text encoder - #[builder(default = "Default::default()")] - clip_l: CLibPath, - - /// path to the clip-g text encoder - #[builder(default = "Default::default()")] - clip_g: CLibPath, - - /// Path to the t5xxl text encoder - #[builder(default = "Default::default()")] - t5xxl: CLibPath, - - /// Path to vae - #[builder(default = "Default::default()")] - vae: CLibPath, - - /// Path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) - #[builder(default = "Default::default()")] - taesd: CLibPath, - - /// Path to control net model - #[builder(default = "Default::default()")] - control_net: CLibPath, - - /// Path to embeddings - #[builder(default = "Default::default()")] - embeddings: CLibPath, - - /// Path to PHOTOMAKER stacked id embeddings - #[builder(default = "Default::default()")] - stacked_id_embd: CLibPath, - - /// Path to PHOTOMAKER input id images dir - #[builder(default = "Default::default()")] - input_id_images: CLibPath, - - /// Normalize PHOTOMAKER input id images - #[builder(default = "false")] - normalize_input: bool, - - /// Path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now - #[builder(default = "Default::default()")] - upscale_model: Option, - - /// Run the ESRGAN upscaler this many times (default 1) - #[builder(default = "0")] - upscale_repeats: i32, - - /// Weight type. If not specified, the default is the type of the weight file - #[builder(default = "WeightType::SD_TYPE_COUNT")] - weight_type: WeightType, - - /// Lora model directory - #[builder(default = "Default::default()", setter(custom))] - lora_model: CLibPath, - - /// Path to the input image, required by img2img - #[builder(default = "Default::default()")] - init_img: CLibPath, - - /// Path to image condition, control net - #[builder(default = "Default::default()")] - control_image: CLibPath, - - /// Path to write result image to (default: ./output.png) - #[builder(default = "PathBuf::from(\"./output.png\")")] - output: PathBuf, - - /// The prompt to render - prompt: String, - - /// The negative prompt (default: "") - #[builder(default = "\"\".into()")] - negative_prompt: CLibString, - - /// Unconditional guidance scale (default: 7.0) - #[builder(default = "7.0")] - cfg_scale: f32, - - /// Guidance (default: 3.5) - #[builder(default = "3.5")] - guidance: f32, - - /// Strength for noising/unnoising (default: 0.75) - #[builder(default = "0.75")] - strength: f32, - - /// Strength for keeping input identity (default: 20%) - #[builder(default = "20.0")] - style_ratio: f32, - - /// Strength to apply Control Net (default: 0.9) - /// 1.0 corresponds to full destruction of information in init - #[builder(default = "0.9")] - control_strength: f32, - - /// Image height, in pixel space (default: 512) - #[builder(default = "512")] - height: i32, - - /// Image width, in pixel space (default: 512) - #[builder(default = "512")] - width: i32, - - /// Sampling-method (default: EULER_A) - #[builder(default = "SampleMethod::EULER_A")] - sampling_method: SampleMethod, - - /// Number of sample steps (default: 20) - #[builder(default = "20")] - steps: i32, - - /// RNG (default: CUDA) - #[builder(default = "RngFunction::CUDA_RNG")] - rng: RngFunction, - - /// RNG seed (default: 42, use random seed for < 0) - #[builder(default = "42")] - seed: i64, - - /// Number of images to generate (default: 1) - #[builder(default = "1")] - batch_count: i32, - - /// Denoiser sigma schedule (default: DEFAULT) - #[builder(default = "Schedule::DEFAULT")] - schedule: Schedule, - - /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) - /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x - #[builder(default = "ClipSkip::Unspecified")] - clip_skip: ClipSkip, - - /// Process vae in tiles to reduce memory usage (default: false) - #[builder(default = "false")] - vae_tiling: bool, - - /// free memory of params immediately after forward (default: true) - #[builder(default = "true")] - free_params_immediately: bool, - - /// Keep vae in cpu (for low vram) (default: false) - #[builder(default = "false")] - vae_on_cpu: bool, - - /// keep clip in cpu (for low vram) (default: false) - #[builder(default = "false")] - clip_on_cpu: bool, - - /// Keep controlnet in cpu (for low vram) (default: false) - #[builder(default = "false")] - control_net_cpu: bool, - - /// Apply canny preprocessor (edge detection) (default: false) - #[builder(default = "false")] - canny: bool, - - /// Suffix that needs to be added to prompt (e.g. lora model) - #[builder(default = "None", private)] - prompt_suffix: Option, - - /// Use flash attention in the diffusion model (for low vram). - /// Might lower quality, since it implies converting k and v to f16. - /// This might crash if it is not supported by the backend. - #[builder(default = "false")] - flash_attention: bool, - - /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) - /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium - #[builder(default = "0.")] - slg_scale: f32, - - /// Layers to skip for SLG steps: (default: [7,8,9]) - #[builder(default = "vec![7, 8, 9]")] - skip_layer: Vec, - - /// SLG enabling point: (default: 0.01) - #[builder(default = "0.01")] - skip_layer_start: f32, - - /// SLG disabling point: (default: 0.2) - #[builder(default = "0.2")] - skip_layer_end: f32, -} - -impl ConfigBuilder { - /// add Lora model and clip strength to the prompt suffix - /// e.g. "" - pub fn lora_model(&mut self, lora_model: &Path, clip_strength: f32) -> &mut Self { - let folder = lora_model.parent().unwrap(); - let file_name = lora_model.file_stem().unwrap().to_str().unwrap().to_owned(); - self.prompt_suffix(format!("")); - self.lora_model = Some(folder.into()); - self - } - - pub fn n_threads(&mut self, value: i32) -> &mut Self { - self.n_threads = if value > 0 { - Some(value) - } else { - Some(num_cpus::get_physical() as i32) - }; - self - } - - fn validate(&self) -> Result<(), ConfigBuilderError> { - self.validate_model()?; - self.validate_output_dir() - } - - fn validate_model(&self) -> Result<(), ConfigBuilderError> { - self.model - .as_ref() - .or(self.diffusion_model.as_ref()) - .map(|_| ()) - .ok_or(ConfigBuilderError::UninitializedField( - "Model OR DiffusionModel must be valorized", - )) - } - - fn validate_output_dir(&self) -> Result<(), ConfigBuilderError> { - let is_dir = self.output.as_ref().is_some_and(|val| val.is_dir()); - let multiple_items = self.batch_count.as_ref().is_some_and(|val| *val > 1); - if is_dir == multiple_items { - Ok(()) - } else { - Err(ConfigBuilderError::ValidationError( - "When batch_count > 0, ouput should point to folder and viceversa".to_owned(), - )) - } - } -} - -impl Config { - unsafe fn build_sd_ctx(&self, vae_decode_only: bool) -> *mut sd_ctx_t { - new_sd_ctx( - self.model.as_ptr(), - self.clip_l.as_ptr(), - self.clip_g.as_ptr(), - self.t5xxl.as_ptr(), - self.diffusion_model.as_ptr(), - self.vae.as_ptr(), - self.taesd.as_ptr(), - self.control_net.as_ptr(), - self.lora_model.as_ptr(), - self.embeddings.as_ptr(), - self.stacked_id_embd.as_ptr(), - vae_decode_only, - self.vae_tiling, - self.free_params_immediately, - self.n_threads, - self.weight_type, - self.rng, - self.schedule, - self.clip_on_cpu, - self.control_net_cpu, - self.vae_on_cpu, - self.flash_attention, - ) - } - - unsafe fn upscaler_ctx(&self) -> Option<*mut upscaler_ctx_t> { - if self.upscale_model.is_none() || self.upscale_repeats == 0 { - None - } else { - let upscaler = new_upscaler_ctx( - self.upscale_model.as_ref().unwrap().as_ptr(), - self.n_threads, - ); - Some(upscaler) - } - } -} - -#[derive(Debug, Clone, Default)] -struct CLibString(CString); - -impl CLibString { - fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() - } -} - -impl From<&str> for CLibString { - fn from(value: &str) -> Self { - Self(CString::new(value).unwrap()) - } -} - -impl From for CLibString { - fn from(value: String) -> Self { - Self(CString::new(value).unwrap()) - } -} - -#[derive(Debug, Clone, Default)] -struct CLibPath(CString); - -impl CLibPath { - fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() - } -} - -impl From for CLibPath { - fn from(value: PathBuf) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) - } -} - -impl From<&Path> for CLibPath { - fn from(value: &Path) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) - } -} - -fn output_files(path: PathBuf, batch_size: i32) -> Vec { - if batch_size == 1 { - vec![path.into()] - } else { - (1..=batch_size) - .map(|id| path.join(format!("output_{id}.png")).into()) - .collect() - } -} - -unsafe fn upscale( - upscale_repeats: i32, - upscaler_ctx: Option<*mut upscaler_ctx_t>, - data: sd_image_t, -) -> Result { - match upscaler_ctx { - Some(upscaler_ctx) => { - let upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth - let mut current_image = data; - for _ in 0..upscale_repeats { - let upscaled_image = - diffusion_rs_sys::upscale(upscaler_ctx, current_image, upscale_factor); - - if upscaled_image.data.is_null() { - return Err(DiffusionError::Upscaler); - } - - free(current_image.data as *mut c_void); - current_image = upscaled_image; - } - Ok(current_image) - } - None => Ok(data), - } -} - -/// Generate an image with a prompt -pub fn txt2img(mut config: Config) -> Result<(), DiffusionError> { - unsafe { - let prompt: CLibString = match &config.prompt_suffix { - Some(suffix) => format!("{} {suffix}", &config.prompt), - None => config.prompt.clone(), - } - .into(); - let sd_ctx = config.build_sd_ctx(true); - let upscaler_ctx = config.upscaler_ctx(); - let res = { - let slice = diffusion_rs_sys::txt2img( - sd_ctx, - prompt.as_ptr(), - config.negative_prompt.as_ptr(), - config.clip_skip as i32, - config.cfg_scale, - config.guidance, - config.width, - config.height, - config.sampling_method, - config.steps, - config.seed, - config.batch_count, - null(), - config.control_strength, - config.style_ratio, - config.normalize_input, - config.input_id_images.as_ptr(), - config.skip_layer.as_mut_ptr(), - config.skip_layer.len(), - config.slg_scale, - config.skip_layer_start, - config.skip_layer_end, - ); - if slice.is_null() { - return Err(DiffusionError::Forward); - } - let files = output_files(config.output, config.batch_count); - for (id, (img, path)) in slice::from_raw_parts(slice, config.batch_count as usize) - .iter() - .zip(files) - .enumerate() - { - match upscale(config.upscale_repeats, upscaler_ctx, *img) { - Ok(img) => { - let status = stbi_write_png_custom( - path.as_ptr(), - img.width as i32, - img.height as i32, - img.channel as i32, - img.data as *const c_void, - 0, - ); - if status == 0 { - return Err(DiffusionError::StoreImages(id, config.batch_count)); - } - } - Err(err) => { - return Err(err); - } - } - } - - //Clean-up slice section - free(slice as *mut c_void); - Ok(()) - }; - - //Clean-up CTX section - free_sd_ctx(sd_ctx); - if let Some(upscaler_ctx) = upscaler_ctx { - free_upscaler_ctx(upscaler_ctx); - } - res - } -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use crate::{api::ConfigBuilderError, util::download_file_hf_hub}; - - use super::{txt2img, ConfigBuilder}; - - #[test] - fn test_required_args_txt2img() { - assert!(ConfigBuilder::default().build().is_err()); - assert!(ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .build() - .is_err()); - - assert!(ConfigBuilder::default() - .prompt("a lovely cat driving a sport car") - .build() - .is_err()); - - assert!(matches!( - ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .prompt("a lovely cat driving a sport car") - .batch_count(10) - .build(), - Err(ConfigBuilderError::ValidationError(_)) - )); - - ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .prompt("a lovely cat driving a sport car") - .build() - .unwrap(); - - ConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .prompt("a lovely duck drinking water from a bottle") - .batch_count(2) - .output(PathBuf::from("./")) - .build() - .unwrap(); - } - - #[ignore] - #[test] - fn test_txt2img() { - let model_path = - download_file_hf_hub("CompVis/stable-diffusion-v-1-4-original", "sd-v1-4.ckpt") - .unwrap(); - - let upscaler_path = download_file_hf_hub( - "ximso/RealESRGAN_x4plus_anime_6B", - "RealESRGAN_x4plus_anime_6B.pth", - ) - .unwrap(); - let config = ConfigBuilder::default() - .model(model_path) - .prompt("a lovely duck drinking water from a bottle") - .output(PathBuf::from("./output_1.png")) - .upscale_model(upscaler_path) - .upscale_repeats(1) - .batch_count(1) - .build() - .unwrap(); - txt2img(config).unwrap(); - } -} From 02336e982868e64e0239f3f7c2cfe61c41f2c2ed Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 7 Feb 2025 13:45:52 -0500 Subject: [PATCH 13/33] remove warnings --- src/utils.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index 9461e3f..41891bf 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,10 +1,8 @@ -use diffusion_rs_sys::sd_get_system_info; use diffusion_rs_sys::sd_log_level_t; use diffusion_rs_sys::sd_set_log_callback; use image::ImageBuffer; use image::Rgb; use image::RgbImage; -use libc::free; use std::ffi::c_char; use std::ffi::c_void; use std::ffi::CStr; From a976c0d3e4d2be55f6faa08c0a458358c55bd6a1 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 10 Feb 2025 15:47:46 -0500 Subject: [PATCH 14/33] remove lifetimes and allow for multithreading (unsafe) --- src/api.rs | 142 +++++++++++++++++++----------------------- src/txt2img_config.rs | 8 +-- token.txt | 1 - 3 files changed, 69 insertions(+), 82 deletions(-) delete mode 100644 token.txt diff --git a/src/api.rs b/src/api.rs index 9c31a90..5cdfc91 100644 --- a/src/api.rs +++ b/src/api.rs @@ -16,6 +16,9 @@ pub struct ModelCtx { pub model_config: ModelConfig, } +unsafe impl Send for ModelCtx {} +unsafe impl Sync for ModelCtx {} + impl ModelCtx { pub fn new(config: ModelConfig) -> Result { setup_logging(); @@ -158,7 +161,8 @@ mod tests { use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; use std::path::PathBuf; - use std::sync::Arc; + use std::sync::{Arc, Mutex}; + use std::thread; #[test] fn test_invalid_model_config() { @@ -198,17 +202,6 @@ mod tests { #[test] fn test_txt2img_success() { - let resolution: i32 = 384; - - let control_image1 = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .into_rgb8(); - - let sample_steps = 4; - let control_strength = 0.4; - let ctx = ModelCtx::new( ModelConfigBuilder::default() .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) @@ -222,76 +215,71 @@ mod tests { .rng_type(RngFunction::CUDA_RNG) .vae_decode_only(true) .schedule(Schedule::AYS) + .n_threads(-1) .build() .expect("Failed to build model config"), ) .expect("Failed to build model context"); - let result = ctx - .txt2img(Txt2ImgConfigBuilder::default() - .prompt("masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk") - .control_cond(&control_image1) - .control_strength(control_strength) - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(ClipSkip::OneLayer) - .batch_count(1) - .build() - .expect("Failed to build txt2img config 1")) - .expect("Failed to generate image 1"); - - result.iter().enumerate().for_each(|(i, img)| { - img.save(format!("./images/test_1_{}x_{}.png", resolution, i)) - .unwrap(); - }); - - let result2 = ctx - .txt2img(Txt2ImgConfigBuilder::default() - .prompt("masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy") - .control_cond(&control_image1) - .control_strength(control_strength) - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(ClipSkip::OneLayer) - .batch_count(1) - .build() - .expect("Failed to build txt2img config 2")) - .expect("Failed to generate image 2"); - - result2.iter().enumerate().for_each(|(i, img)| { - img.save(format!("./images/test_2_{}x_{}.png", resolution, i)) - .unwrap(); - }); - - // let result3 = ctx - // .txt2img(Txt2ImgConfigBuilder::default() - // .prompt("masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy") - // .control_cond(control_image3) - // .control_strength(control_strength) - // .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - // .sample_steps(sample_steps) - // .sample_method(SampleMethod::LCM) - // .cfg_scale(1.0) - // .height(resolution) - // .width(resolution) - // .clip_skip(ClipSkip::OneLayer) - // .batch_count(1) - // .build() - // .expect("Failed to build txt2img config 1")) - // .expect("Failed to generate image 1"); - - // result3.iter().enumerate().for_each(|(i, img)| { - // img.save(format!("./images/test_3_{}x_{}.png", resolution, i)) - // .unwrap(); - // }); + let resolution: i32 = 384; + let sample_steps = 1; + let control_strength = 0.4; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .into_rgb8(); + + let prompts = vec![ + "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", + "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy" + ]; + + let ctx = Arc::new(Mutex::new(ctx)); + + let mut handles = vec![]; + + for (index, prompt) in prompts.into_iter().enumerate() { + let txt2img_config = Txt2ImgConfigBuilder::default() + .prompt(prompt) + .control_cond(control_image.clone()) + .control_strength(control_strength) + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(ClipSkip::OneLayer) + .batch_count(1) + .build() + .expect("Failed to build txt2img config 1"); + + let ctx = Arc::clone(&ctx); + + let handle = thread::spawn(move || { + let result = ctx + .lock() + .unwrap() + .txt2img(txt2img_config) + .expect("Failed to generate image 1"); + + result.iter().enumerate().for_each(|(batch, img)| { + img.save(format!( + "./images/test_#{}_{}x_{}.png", + index, resolution, batch + )) + .unwrap(); + }); + }); + + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } } #[test] diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index 857e975..8c00f7d 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -5,7 +5,7 @@ use image::RgbImage; #[derive(Builder, Debug, Clone)] #[builder(setter(into), build_fn(validate = "Self::validate"))] /// txt2img config -pub struct Txt2ImgConfig<'a> { +pub struct Txt2ImgConfig { /// Prompt to generate image from pub prompt: String, @@ -26,7 +26,7 @@ pub struct Txt2ImgConfig<'a> { #[builder(default = "7.0")] pub cfg_scale: f32, - /// Guidance (default: 3.5) + /// Guidance (default: 3.5) for Flux/DiT models #[builder(default = "3.5")] pub guidance: f32, @@ -55,7 +55,7 @@ pub struct Txt2ImgConfig<'a> { pub batch_count: i32, #[builder(setter(strip_option), default)] - pub control_cond: Option<&'a RgbImage>, + pub control_cond: Option, /// Strength to apply Control Net (default: 0.9) /// 1.0 corresponds to full destruction of information in init @@ -92,7 +92,7 @@ pub struct Txt2ImgConfig<'a> { pub skip_layer_end: f32, } -impl<'a> Txt2ImgConfigBuilder<'a> { +impl Txt2ImgConfigBuilder { fn validate(&self) -> Result<(), Txt2ImgConfigBuilderError> { self.validate_prompt() } diff --git a/token.txt b/token.txt deleted file mode 100644 index 5ce8df9..0000000 --- a/token.txt +++ /dev/null @@ -1 +0,0 @@ -Your hf-hub token \ No newline at end of file From e1b81234f66ca6c202e22005bff71bf690741919 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 10 Feb 2025 20:55:15 -0500 Subject: [PATCH 15/33] arc? i suppose so --- src/api.rs | 24 ++++++++++++------------ src/txt2img_config.rs | 8 +++++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/api.rs b/src/api.rs index 5cdfc91..71cd579 100644 --- a/src/api.rs +++ b/src/api.rs @@ -61,10 +61,7 @@ impl ModelCtx { }) } - pub fn txt2img( - &self, - mut txt2img_config: Txt2ImgConfig, - ) -> Result, DiffusionError> { + pub fn txt2img(&self, txt2img_config: Txt2ImgConfig) -> Result, DiffusionError> { // add loras to prompt as suffix let prompt: CLibString = { let mut prompt = txt2img_config.prompt.clone(); @@ -113,7 +110,7 @@ impl ModelCtx { txt2img_config.style_strength, txt2img_config.normalize_input, txt2img_config.input_id_images.as_ptr(), - txt2img_config.skip_layer.as_mut_ptr(), + txt2img_config.skip_layer.as_ptr() as *mut i32, txt2img_config.skip_layer.len(), txt2img_config.slg_scale, txt2img_config.skip_layer_start, @@ -161,6 +158,7 @@ mod tests { use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; use std::path::PathBuf; + use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::thread; @@ -222,13 +220,15 @@ mod tests { .expect("Failed to build model context"); let resolution: i32 = 384; - let sample_steps = 1; + let sample_steps = 6; let control_strength = 0.4; - let control_image = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .into_rgb8(); + let control_image = Arc::new( + ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .into_rgb8(), + ); let prompts = vec![ "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", @@ -243,7 +243,7 @@ mod tests { for (index, prompt) in prompts.into_iter().enumerate() { let txt2img_config = Txt2ImgConfigBuilder::default() .prompt(prompt) - .control_cond(control_image.clone()) + .control_cond(Arc::clone(&control_image)) .control_strength(control_strength) .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) .sample_steps(sample_steps) diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index 8c00f7d..235bc7f 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use crate::utils::{CLibPath, CLibString, ClipSkip, SampleMethod}; use derive_builder::Builder; use image::RgbImage; #[derive(Builder, Debug, Clone)] -#[builder(setter(into), build_fn(validate = "Self::validate"))] +#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] /// txt2img config pub struct Txt2ImgConfig { /// Prompt to generate image from @@ -54,8 +56,8 @@ pub struct Txt2ImgConfig { #[builder(default = "1")] pub batch_count: i32, - #[builder(setter(strip_option), default)] - pub control_cond: Option, + #[builder(default = "None")] + pub control_cond: Option>, /// Strength to apply Control Net (default: 0.9) /// 1.0 corresponds to full destruction of information in init From eac613cb76edded655b873f9959b00d2945dce38 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 14 Feb 2025 16:16:30 -0500 Subject: [PATCH 16/33] add logging callbacks --- src/api.rs | 23 ++++++++++++++-------- src/model_config.rs | 13 +++++++++++++ src/utils.rs | 47 +++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/src/api.rs b/src/api.rs index 71cd579..26982a3 100644 --- a/src/api.rs +++ b/src/api.rs @@ -21,7 +21,7 @@ unsafe impl Sync for ModelCtx {} impl ModelCtx { pub fn new(config: ModelConfig) -> Result { - setup_logging(); + setup_logging(config.log_callback, config.progress_callback); let raw_ctx = unsafe { let ptr = new_sd_ctx( @@ -87,7 +87,10 @@ impl ModelCtx { } } } - None => null(), + None => { + println!("Control net conditioning image is null, setting control image to null"); + null() + } }; //run text to image @@ -158,7 +161,6 @@ mod tests { use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; use std::path::PathBuf; - use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::thread; @@ -206,7 +208,7 @@ mod tests { .lora_model_dir(PathBuf::from("./models/loras")) .taesd(PathBuf::from("./models/taesd1.safetensors")) .control_net(PathBuf::from( - "./models/controlnet/control_canny-fp16.safetensors", + "./models/controlnet/control_v11f1e_sd15_tile_fp16.safetensors", )) .weight_type(WeightType::SD_TYPE_F16) .flash_attention(true) @@ -221,19 +223,24 @@ mod tests { let resolution: i32 = 384; let sample_steps = 6; - let control_strength = 0.4; + let control_strength = 0.6; let control_image = Arc::new( - ImageReader::open("./images/canny-384x.jpg") + ImageReader::open("./images/samusweapon_d.png") .expect("Failed to open image") .decode() .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) .into_rgb8(), ); let prompts = vec![ "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", - "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", - "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy" + //"masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + //"masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy" ]; let ctx = Arc::new(Mutex::new(ctx)); diff --git a/src/model_config.rs b/src/model_config.rs index b410095..2ffb143 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,4 +1,7 @@ +use std::ffi::{c_char, c_void}; + use derive_builder::Builder; +use diffusion_rs_sys::sd_log_level_t; use crate::utils::{CLibPath, RngFunction, Schedule, WeightType}; @@ -102,6 +105,16 @@ pub struct ModelConfig { /// (default: false) #[builder(default = "false")] pub flash_attention: bool, + + /// set log callback function for cpp logs (default: None) + #[builder(default = "None")] + pub log_callback: + Option, + + /// set log callback function for progress logs (default: None) + #[builder(default = "None")] + pub progress_callback: + Option, } impl ModelConfigBuilder { diff --git a/src/utils.rs b/src/utils.rs index 41891bf..44558b0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,6 @@ use diffusion_rs_sys::sd_log_level_t; use diffusion_rs_sys::sd_set_log_callback; +use diffusion_rs_sys::sd_set_progress_callback; use image::ImageBuffer; use image::Rgb; use image::RgbImage; @@ -7,7 +8,6 @@ use std::ffi::c_char; use std::ffi::c_void; use std::ffi::CStr; use std::ffi::CString; - use std::path::Path; use std::path::PathBuf; use std::slice; @@ -122,7 +122,7 @@ pub fn convert_image(sd_image: &sd_image_t) -> Result { }) } -extern "C" fn my_log_callback(level: sd_log_level_t, text: *const c_char, _data: *mut c_void) { +extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _data: *mut c_void) { unsafe { // Convert C string to Rust &str and print it. if !text.is_null() { @@ -132,8 +132,47 @@ extern "C" fn my_log_callback(level: sd_log_level_t, text: *const c_char, _data: } } -pub fn setup_logging() { +// use std::sync::LazyLock; + +// static BAR: LazyLock> = LazyLock::new(|| { +// Mutex::new(ProgressBar::no_length().with_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})") +// .unwrap() +// .progress_chars("#>-"))) +// }); + +// /// This is your C callback that gets called with current progress. +// extern "C" fn default_progress_callback(step: c_int, steps: c_int, time: f32, _data: *mut c_void) { +// // Update the global progress bar if it's been initialized. +// let mut bar = BAR.lock().unwrap(); + +// if bar.is_finished() { +// *bar = ProgressBar::no_length().with_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})") +// .unwrap() +// .progress_chars("#>-")); +// } else { +// if steps == step { +// bar.finish_with_message("Done"); +// } +// bar.set_length(steps as u64); +// bar.set_position(step as u64); +// bar.set_message(format!("Elapsed: {:.2} s", time)); +// } +// } + +pub fn setup_logging( + log_callback: Option< + extern "C" fn(level: sd_log_level_t, text: *const c_char, _data: *mut c_void), + >, + progress_callback: Option, +) { unsafe { - sd_set_log_callback(Some(my_log_callback), std::ptr::null_mut()); + match log_callback { + Some(callback) => sd_set_log_callback(Some(callback), std::ptr::null_mut()), + None => sd_set_log_callback(Some(default_log_callback), std::ptr::null_mut()), + }; + match progress_callback { + Some(callback) => sd_set_progress_callback(Some(callback), std::ptr::null_mut()), + None => (), + }; } } From 3bcb12c483739c5aee16bb4de4b407d0c1276ea5 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 14 Feb 2025 17:05:38 -0500 Subject: [PATCH 17/33] expose log level --- src/utils.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.rs b/src/utils.rs index 44558b0..10c5667 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -108,6 +108,8 @@ pub use diffusion_rs_sys::sd_type_t as WeightType; /// Sampling methods pub use diffusion_rs_sys::sample_method_t as SampleMethod; +use diffusion_rs_sys::sd_log_level_t as SD_LOG_LEVEL_T; + use diffusion_rs_sys::sd_image_t; pub fn convert_image(sd_image: &sd_image_t) -> Result { @@ -161,7 +163,7 @@ extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _ pub fn setup_logging( log_callback: Option< - extern "C" fn(level: sd_log_level_t, text: *const c_char, _data: *mut c_void), + extern "C" fn(level: SD_LOG_LEVEL_T, text: *const c_char, _data: *mut c_void), >, progress_callback: Option, ) { From 2bbbd76a77ddab8b1850be608542943f31225746 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 14 Feb 2025 17:06:34 -0500 Subject: [PATCH 18/33] publicize it --- src/utils.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils.rs b/src/utils.rs index 10c5667..7009439 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -108,7 +108,8 @@ pub use diffusion_rs_sys::sd_type_t as WeightType; /// Sampling methods pub use diffusion_rs_sys::sample_method_t as SampleMethod; -use diffusion_rs_sys::sd_log_level_t as SD_LOG_LEVEL_T; +//log level +pub use diffusion_rs_sys::sd_log_level_t as SD_LOG_LEVEL_T; use diffusion_rs_sys::sd_image_t; From 684fd6044582cfc064b9925cc6caeedc08fe3755 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 14 Feb 2025 17:10:22 -0500 Subject: [PATCH 19/33] strip option --- src/model_config.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_config.rs b/src/model_config.rs index 2ffb143..e6945ed 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -6,7 +6,7 @@ use diffusion_rs_sys::sd_log_level_t; use crate::utils::{CLibPath, RngFunction, Schedule, WeightType}; #[derive(Builder, Debug, Clone)] -#[builder(setter(into), build_fn(validate = "Self::validate"))] +#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] /// Config struct common to all diffusion methods pub struct ModelConfig { /// Path to full model From da6180ba3dc84b886258a85378580db5cade975d Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 17 Feb 2025 16:51:39 -0500 Subject: [PATCH 20/33] use sd's get physical cores --- src/model_config.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/model_config.rs b/src/model_config.rs index e6945ed..77f8d68 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,7 +1,7 @@ use std::ffi::{c_char, c_void}; use derive_builder::Builder; -use diffusion_rs_sys::sd_log_level_t; +use diffusion_rs_sys::{get_num_physical_cores, sd_log_level_t}; use crate::utils::{CLibPath, RngFunction, Schedule, WeightType}; @@ -68,10 +68,7 @@ pub struct ModelConfig { /// Number of threads to use during computation (default: 0). /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. - #[builder( - default = "std::thread::available_parallelism().map_or(1, |p| p.get() as i32)", - setter(custom) - )] + #[builder(default = "unsafe { get_num_physical_cores() }", setter(custom))] pub n_threads: i32, /// Weight type. If not specified, the default is the type of the weight file @@ -122,7 +119,7 @@ impl ModelConfigBuilder { self.n_threads = if value > 0 { Some(value) } else { - Some(std::thread::available_parallelism().map_or(1, |p| p.get() as i32)) + Some(unsafe { get_num_physical_cores() }) }; self } From 99d24b2601b4bec2a84c798254701e8d379fd147 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Mon, 17 Feb 2025 23:05:32 -0500 Subject: [PATCH 21/33] remove clibstring and clibpath, to create more abstraction --- src/api.rs | 234 +++++++++++++++++++++++++----------------- src/model_config.rs | 25 ++--- src/txt2img_config.rs | 12 +-- src/utils.rs | 71 ++----------- 4 files changed, 170 insertions(+), 172 deletions(-) diff --git a/src/api.rs b/src/api.rs index 26982a3..3918f58 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,19 +1,21 @@ +use std::ffi::{c_void, CString}; +use std::ops::{Deref, DerefMut}; +use std::ptr::null; +use std::slice; + use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; -use crate::utils::{convert_image, setup_logging, CLibString, DiffusionError}; -use diffusion_rs_sys::{free_sd_ctx, new_sd_ctx, sd_ctx_t, sd_image_t, txt2img}; +use crate::utils::{convert_image, pathbuf_to_c_char, setup_logging, DiffusionError}; +use diffusion_rs_sys::sd_image_t; use image::RgbImage; -use libc::{free, strlen}; -use std::ffi::c_void; -use std::ptr::null; -use std::slice; +use libc::free; pub struct ModelCtx { /// The underlying C context - raw_ctx: Option<*mut sd_ctx_t>, + ctx: *mut diffusion_rs_sys::sd_ctx_t, /// We keep the config around in case we need to refer to it - pub model_config: ModelConfig, + pub config: ModelConfig, } unsafe impl Send for ModelCtx {} @@ -23,19 +25,19 @@ impl ModelCtx { pub fn new(config: ModelConfig) -> Result { setup_logging(config.log_callback, config.progress_callback); - let raw_ctx = unsafe { - let ptr = new_sd_ctx( - config.model.as_ptr(), - config.clip_l.as_ptr(), - config.clip_g.as_ptr(), - config.t5xxl.as_ptr(), - config.diffusion_model.as_ptr(), - config.vae.as_ptr(), - config.taesd.as_ptr(), - config.control_net.as_ptr(), - config.lora_model_dir.as_ptr(), - config.embeddings_dir.as_ptr(), - config.stacked_id_embd_dir.as_ptr(), + let ctx = unsafe { + let ptr = diffusion_rs_sys::new_sd_ctx( + pathbuf_to_c_char(&config.model).as_ptr(), + pathbuf_to_c_char(&config.clip_l).as_ptr(), + pathbuf_to_c_char(&config.clip_g).as_ptr(), + pathbuf_to_c_char(&config.t5xxl).as_ptr(), + pathbuf_to_c_char(&config.diffusion_model).as_ptr(), + pathbuf_to_c_char(&config.vae).as_ptr(), + pathbuf_to_c_char(&config.taesd).as_ptr(), + pathbuf_to_c_char(&config.control_net).as_ptr(), + pathbuf_to_c_char(&config.lora_model_dir).as_ptr(), + pathbuf_to_c_char(&config.embeddings_dir).as_ptr(), + pathbuf_to_c_char(&config.stacked_id_embd_dir).as_ptr(), config.vae_decode_only, config.vae_tiling, config.free_params_immediately, @@ -51,55 +53,55 @@ impl ModelCtx { if ptr.is_null() { return Err(DiffusionError::NewContextFailure); } else { - Some(ptr) + ptr } }; - Ok(Self { - raw_ctx, - model_config: config, - }) + Ok(Self { ctx, config }) } - pub fn txt2img(&self, txt2img_config: Txt2ImgConfig) -> Result, DiffusionError> { + pub fn txt2img( + &self, + mut txt2img_config: Txt2ImgConfig, + ) -> Result, DiffusionError> { // add loras to prompt as suffix - let prompt: CLibString = { + let prompt: CString = { let mut prompt = txt2img_config.prompt.clone(); for lora in txt2img_config.lora_prompt_suffix.iter() { prompt.push_str(lora); } - prompt.into() + CString::new(prompt).expect("Failed to convert prompt to CString") }; + let negative_prompt = CString::new(txt2img_config.negative_prompt.clone()) + .expect("Failed to convert negative prompt to CString"); + //controlnet let control_image: *const sd_image_t = match txt2img_config.control_cond { - Some(image) => { - match unsafe { strlen(self.model_config.control_net.as_ptr()) as usize > 0 } { - true => &sd_image_t { - data: image.as_ptr() as *mut u8, - width: image.width(), - height: image.height(), - channel: 3, - }, - false => { - println!("Control net model is null, setting control image to null"); - null() - } + Some(image) => match self.config.control_net.is_file() { + true => &sd_image_t { + data: image.as_ptr() as *mut u8, + width: image.width(), + height: image.height(), + channel: 3, + }, + false => { + println!("Control net model is null, setting control image to null"); + null() } - } + }, None => { println!("Control net conditioning image is null, setting control image to null"); null() } }; - //run text to image - let results: *mut sd_image_t = unsafe { - txt2img( - self.raw_ctx.ok_or(DiffusionError::NoContext)?, + let results = unsafe { + diffusion_rs_sys::txt2img( + self.ctx, prompt.as_ptr(), - txt2img_config.negative_prompt.as_ptr(), - txt2img_config.clip_skip as i32, + negative_prompt.as_ptr(), + txt2img_config.clip_skip, txt2img_config.cfg_scale, txt2img_config.guidance, txt2img_config.width, @@ -112,8 +114,8 @@ impl ModelCtx { txt2img_config.control_strength, txt2img_config.style_strength, txt2img_config.normalize_input, - txt2img_config.input_id_images.as_ptr(), - txt2img_config.skip_layer.as_ptr() as *mut i32, + pathbuf_to_c_char(&txt2img_config.input_id_images).as_ptr(), + txt2img_config.skip_layer.as_mut_ptr(), txt2img_config.skip_layer.len(), txt2img_config.slg_scale, txt2img_config.skip_layer_start, @@ -145,22 +147,16 @@ impl ModelCtx { /// Automatic cleanup on drop impl Drop for ModelCtx { fn drop(&mut self) { - match self.raw_ctx { - Some(ptr) => unsafe { - free_sd_ctx(ptr); - }, - None => {} - } + unsafe { diffusion_rs_sys::free_sd_ctx(self.ctx) }; } } #[cfg(test)] mod tests { use super::*; - use crate::utils::{ClipSkip, RngFunction, SampleMethod, Schedule, WeightType}; + use crate::utils::{RngFunction, SampleMethod, Schedule, WeightType}; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use image::ImageReader; - use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::thread; @@ -172,9 +168,7 @@ mod tests { #[test] fn test_valid_model_config() { - let config = ModelConfigBuilder::default() - .model(PathBuf::from("./test.ckpt")) - .build(); + let config = ModelConfigBuilder::default().model("./test.ckpt").build(); assert!(config.is_ok(), "ModelConfig should succeed with model path"); } @@ -201,31 +195,87 @@ mod tests { } #[test] - fn test_txt2img_success() { - let ctx = ModelCtx::new( - ModelConfigBuilder::default() - .model(PathBuf::from("./models/mistoonAnime_v30.safetensors")) - .lora_model_dir(PathBuf::from("./models/loras")) - .taesd(PathBuf::from("./models/taesd1.safetensors")) - .control_net(PathBuf::from( - "./models/controlnet/control_v11f1e_sd15_tile_fp16.safetensors", - )) - .weight_type(WeightType::SD_TYPE_F16) - .flash_attention(true) - .rng_type(RngFunction::CUDA_RNG) - .vae_decode_only(true) - .schedule(Schedule::AYS) - .n_threads(-1) - .build() - .expect("Failed to build model config"), - ) - .expect("Failed to build model context"); + fn test_txt2img_singlethreaded_success() { + let model_config = ModelConfigBuilder::default() + .model("./models/mistoonAnime_v30.safetensors") + .lora_model_dir("./models/loras") + .taesd("./models/taesd1.safetensors") + .control_net("./models/controlnet/control_canny-fp16.safetensors") + .weight_type(WeightType::SD_TYPE_F16) + .rng_type(RngFunction::CUDA_RNG) + .schedule(Schedule::AYS) + .vae_decode_only(true) + .flash_attention(true) + .build() + .expect("Failed to build model config"); + + let ctx = ModelCtx::new(model_config).expect("Failed to build model context"); let resolution: i32 = 384; let sample_steps = 6; - let control_strength = 0.6; + let control_strength = 0.8; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) + .into_rgb8(); + + let prompt = "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk"; + + let txt2img_config = Txt2ImgConfigBuilder::default() + .prompt(prompt) + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .control_cond(control_image.clone()) + .control_strength(control_strength) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(2) + .batch_count(2) + .build() + .expect("Failed to build txt2img config 1"); + + let result = ctx + .txt2img(txt2img_config) + .expect("Failed to generate image 1"); + + result.iter().enumerate().for_each(|(batch, img)| { + img.save(format!("./images/test_st_{}x_{}.png", resolution, batch)) + .unwrap(); + }); + } + + #[test] + fn test_txt2img_multithreaded_success() { + let model_config = ModelConfigBuilder::default() + .model("./models/mistoonAnime_v30.safetensors") + .lora_model_dir("./models/loras") + .taesd("./models/taesd1.safetensors") + .control_net("./models/controlnet/control_canny-fp16.safetensors") + .weight_type(WeightType::SD_TYPE_F16) + .rng_type(RngFunction::CUDA_RNG) + .schedule(Schedule::AYS) + .vae_decode_only(true) + .flash_attention(true) + .build() + .expect("Failed to build model config"); + + let ctx = Arc::new(Mutex::new( + ModelCtx::new(model_config).expect("Failed to build model context"), + )); + + let resolution: i32 = 384; + let sample_steps = 3; + let control_strength = 0.9; let control_image = Arc::new( - ImageReader::open("./images/samusweapon_d.png") + ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") .decode() .expect("Failed to decode image") @@ -239,29 +289,27 @@ mod tests { let prompts = vec![ "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", - //"masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", - //"masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy" + "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy" ]; - let ctx = Arc::new(Mutex::new(ctx)); - let mut handles = vec![]; for (index, prompt) in prompts.into_iter().enumerate() { let txt2img_config = Txt2ImgConfigBuilder::default() .prompt(prompt) + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) .control_cond(Arc::clone(&control_image)) .control_strength(control_strength) - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) .cfg_scale(1.0) .height(resolution) .width(resolution) - .clip_skip(ClipSkip::OneLayer) + .clip_skip(2) .batch_count(1) .build() - .expect("Failed to build txt2img config 1"); + .expect("Failed to build txt2img config"); let ctx = Arc::clone(&ctx); @@ -270,11 +318,11 @@ mod tests { .lock() .unwrap() .txt2img(txt2img_config) - .expect("Failed to generate image 1"); + .expect("Failed to generate image"); result.iter().enumerate().for_each(|(batch, img)| { img.save(format!( - "./images/test_#{}_{}x_{}.png", + "./images/test_mt_#{}_{}x_{}.png", index, resolution, batch )) .unwrap(); @@ -293,7 +341,7 @@ mod tests { fn test_txt2img_failure() { // Build a context with invalid data to force failure let config = ModelConfigBuilder::default() - .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) + .model("./mistoonAnime_v10Illustrious.safetensors") .build() .unwrap(); let ctx = ModelCtx::new(config).expect("Failed to build model context"); @@ -312,7 +360,7 @@ mod tests { #[test] fn test_multiple_images() { let config = ModelConfigBuilder::default() - .model(PathBuf::from("./mistoonAnime_v10Illustrious.safetensors")) + .model("./mistoonAnime_v10Illustrious.safetensors") .build() .unwrap(); let ctx = ModelCtx::new(config).expect("Failed to build model context"); diff --git a/src/model_config.rs b/src/model_config.rs index 77f8d68..65f81ca 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,9 +1,10 @@ use std::ffi::{c_char, c_void}; +use std::path::PathBuf; use derive_builder::Builder; use diffusion_rs_sys::{get_num_physical_cores, sd_log_level_t}; -use crate::utils::{CLibPath, RngFunction, Schedule, WeightType}; +use crate::utils::{RngFunction, Schedule, WeightType}; #[derive(Builder, Debug, Clone)] #[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] @@ -11,47 +12,47 @@ use crate::utils::{CLibPath, RngFunction, Schedule, WeightType}; pub struct ModelConfig { /// Path to full model #[builder(default = "Default::default()")] - pub model: CLibPath, + pub model: PathBuf, /// path to the clip-l text encoder #[builder(default = "Default::default()")] - pub clip_l: CLibPath, + pub clip_l: PathBuf, /// path to the clip-g text encoder #[builder(default = "Default::default()")] - pub clip_g: CLibPath, + pub clip_g: PathBuf, /// Path to the t5xxl text encoder #[builder(default = "Default::default()")] - pub t5xxl: CLibPath, + pub t5xxl: PathBuf, /// Path to the standalone diffusion model #[builder(default = "Default::default()")] - pub diffusion_model: CLibPath, + pub diffusion_model: PathBuf, /// Path to vae #[builder(default = "Default::default()")] - pub vae: CLibPath, + pub vae: PathBuf, /// Path to taesd. Using Tiny AutoEncoder for fast decoding (lower quality) #[builder(default = "Default::default()")] - pub taesd: CLibPath, + pub taesd: PathBuf, /// Path to control net model #[builder(default = "Default::default()")] - pub control_net: CLibPath, + pub control_net: PathBuf, /// Lora models directory #[builder(default = "Default::default()")] - pub lora_model_dir: CLibPath, + pub lora_model_dir: PathBuf, /// Path to embeddings directory #[builder(default = "Default::default()")] - pub embeddings_dir: CLibPath, + pub embeddings_dir: PathBuf, /// Path to PHOTOMAKER stacked id embeddings #[builder(default = "Default::default()")] - pub stacked_id_embd_dir: CLibPath, + pub stacked_id_embd_dir: PathBuf, //TODO: Add more info here for docs /// vae decode only (default: false) diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index 235bc7f..061af26 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -1,6 +1,6 @@ -use std::sync::Arc; +use std::{path::PathBuf, sync::Arc}; -use crate::utils::{CLibPath, CLibString, ClipSkip, SampleMethod}; +use crate::utils::SampleMethod; use derive_builder::Builder; use image::RgbImage; @@ -17,12 +17,12 @@ pub struct Txt2ImgConfig { /// The negative prompt (default: "") #[builder(default = "\"\".into()")] - pub negative_prompt: CLibString, + pub negative_prompt: String, /// Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) /// <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x - #[builder(default = "ClipSkip::Unspecified")] - pub clip_skip: ClipSkip, + #[builder(default = "0")] + pub clip_skip: i32, /// Unconditional guidance scale (default: 7.0) #[builder(default = "7.0")] @@ -74,7 +74,7 @@ pub struct Txt2ImgConfig { /// Path to PHOTOMAKER input id images dir #[builder(default = "Default::default()")] - pub input_id_images: CLibPath, + pub input_id_images: PathBuf, /// Layers to skip for SLG steps: (default: [7,8,9]) #[builder(default = "vec![7, 8, 9]")] diff --git a/src/utils.rs b/src/utils.rs index 7009439..7ed9263 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,7 +8,6 @@ use std::ffi::c_char; use std::ffi::c_void; use std::ffi::CStr; use std::ffi::CString; -use std::path::Path; use std::path::PathBuf; use std::slice; use thiserror::Error; @@ -29,8 +28,6 @@ pub enum DiffusionError { NewContextFailure, #[error("SD image conversion error: {0}")] SDImageError(#[from] SDImageError), - // #[error("Free Params Immediately is set to true, which means that the params are freed after forward. This means that the model can only be used once")] - // FreeParamsImmediately, } #[non_exhaustive] @@ -42,60 +39,6 @@ pub enum SDImageError { DifferentLength, } -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Default, Clone, Hash, PartialEq, Eq)] -/// Ignore the lower X layers of CLIP network -pub enum ClipSkip { - /// Will be [ClipSkip::None] for SD1.x, [ClipSkip::OneLayer] for SD2.x - #[default] - Unspecified = 0, - None = 1, - OneLayer = 2, -} - -#[derive(Debug, Clone, Default)] -pub struct CLibString(pub CString); - -impl CLibString { - pub fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() - } -} - -impl From<&str> for CLibString { - fn from(value: &str) -> Self { - Self(CString::new(value).unwrap()) - } -} - -impl From for CLibString { - fn from(value: String) -> Self { - Self(CString::new(value).unwrap()) - } -} - -#[derive(Debug, Clone, Default)] -pub struct CLibPath(CString); - -impl CLibPath { - pub fn as_ptr(&self) -> *const c_char { - self.0.as_ptr() - } -} - -impl From for CLibPath { - fn from(value: PathBuf) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) - } -} - -impl From<&Path> for CLibPath { - fn from(value: &Path) -> Self { - Self(CString::new(value.to_str().unwrap_or_default()).unwrap()) - } -} - /// Specify the range function pub use diffusion_rs_sys::rng_type_t as RngFunction; @@ -109,10 +52,18 @@ pub use diffusion_rs_sys::sd_type_t as WeightType; pub use diffusion_rs_sys::sample_method_t as SampleMethod; //log level -pub use diffusion_rs_sys::sd_log_level_t as SD_LOG_LEVEL_T; +pub use diffusion_rs_sys::sd_log_level_t as SdLogLevel; use diffusion_rs_sys::sd_image_t; +pub fn pathbuf_to_c_char(path: &PathBuf) -> CString { + let path_str = path + .to_str() + .expect("PathBuf contained non-UTF-8 characters"); + // Create a CString which adds a null terminator. + CString::new(path_str).expect("CString conversion failed") +} + pub fn convert_image(sd_image: &sd_image_t) -> Result { let len = (sd_image.width * sd_image.height * sd_image.channel) as usize; let raw_pixels = unsafe { slice::from_raw_parts(sd_image.data, len) }; @@ -163,9 +114,7 @@ extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _ // } pub fn setup_logging( - log_callback: Option< - extern "C" fn(level: SD_LOG_LEVEL_T, text: *const c_char, _data: *mut c_void), - >, + log_callback: Option, progress_callback: Option, ) { unsafe { From 535c345048907ea363ed3a6ccbc6eac44f1d46f5 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Sat, 22 Feb 2025 13:47:22 -0500 Subject: [PATCH 22/33] finally fix control image failures and support multithreading --- Cargo.toml | 1 - src/api.rs | 97 ++++++++++++++++++++++--------------------- src/txt2img_config.rs | 4 +- sys/Cargo.toml | 1 - sys/build.rs | 5 --- 5 files changed, 51 insertions(+), 57 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e74009b..d8d4572 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,4 +34,3 @@ hipblas = ["diffusion-rs-sys/hipblas"] metal = ["diffusion-rs-sys/metal"] vulkan = ["diffusion-rs-sys/vulkan"] sycl = ["diffusion-rs-sys/sycl"] -flashattn = ["diffusion-rs-sys/flashattn"] diff --git a/src/api.rs b/src/api.rs index 3918f58..40276b0 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,5 +1,4 @@ use std::ffi::{c_void, CString}; -use std::ops::{Deref, DerefMut}; use std::ptr::null; use std::slice; @@ -62,7 +61,7 @@ impl ModelCtx { pub fn txt2img( &self, - mut txt2img_config: Txt2ImgConfig, + txt2img_config: &mut Txt2ImgConfig, ) -> Result, DiffusionError> { // add loras to prompt as suffix let prompt: CString = { @@ -77,19 +76,20 @@ impl ModelCtx { .expect("Failed to convert negative prompt to CString"); //controlnet - let control_image: *const sd_image_t = match txt2img_config.control_cond { - Some(image) => match self.config.control_net.is_file() { - true => &sd_image_t { - data: image.as_ptr() as *mut u8, - width: image.width(), - height: image.height(), - channel: 3, - }, - false => { + let control_image: *const sd_image_t = match txt2img_config.control_cond.as_mut() { + Some(image) => { + if self.config.control_net.is_file() { + &sd_image_t { + data: image.as_mut_ptr(), + width: image.width(), + height: image.height(), + channel: 3, + } + } else { println!("Control net model is null, setting control image to null"); null() } - }, + } None => { println!("Control net conditioning image is null, setting control image to null"); null() @@ -201,7 +201,7 @@ mod tests { .lora_model_dir("./models/loras") .taesd("./models/taesd1.safetensors") .control_net("./models/controlnet/control_canny-fp16.safetensors") - .weight_type(WeightType::SD_TYPE_F16) + .weight_type(WeightType::SD_TYPE_Q4_1) .rng_type(RngFunction::CUDA_RNG) .schedule(Schedule::AYS) .vae_decode_only(true) @@ -212,8 +212,8 @@ mod tests { let ctx = ModelCtx::new(model_config).expect("Failed to build model context"); let resolution: i32 = 384; - let sample_steps = 6; - let control_strength = 0.8; + let sample_steps = 3; + let control_strength = 0.9; let control_image = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") .decode() @@ -227,10 +227,10 @@ mod tests { let prompt = "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk"; - let txt2img_config = Txt2ImgConfigBuilder::default() + let mut txt2img_config = Txt2ImgConfigBuilder::default() .prompt(prompt) .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .control_cond(control_image.clone()) + .control_cond(control_image) .control_strength(control_strength) .sample_steps(sample_steps) .sample_method(SampleMethod::LCM) @@ -243,7 +243,7 @@ mod tests { .expect("Failed to build txt2img config 1"); let result = ctx - .txt2img(txt2img_config) + .txt2img(&mut txt2img_config) .expect("Failed to generate image 1"); result.iter().enumerate().for_each(|(batch, img)| { @@ -263,7 +263,7 @@ mod tests { .rng_type(RngFunction::CUDA_RNG) .schedule(Schedule::AYS) .vae_decode_only(true) - .flash_attention(true) + .flash_attention(false) .build() .expect("Failed to build model config"); @@ -273,19 +273,17 @@ mod tests { let resolution: i32 = 384; let sample_steps = 3; - let control_strength = 0.9; - let control_image = Arc::new( - ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .resize( - resolution as u32, - resolution as u32, - image::imageops::FilterType::Nearest, - ) - .into_rgb8(), - ); + let control_strength = 0.8; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) + .into_rgb8(); let prompts = vec![ "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", @@ -295,19 +293,22 @@ mod tests { let mut handles = vec![]; + let mut binding = Txt2ImgConfigBuilder::default(); + let txt2img_config_base = binding + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .control_cond(control_image) + .control_strength(control_strength) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(2) + .batch_count(1); + for (index, prompt) in prompts.into_iter().enumerate() { - let txt2img_config = Txt2ImgConfigBuilder::default() + let mut txt2img_config = txt2img_config_base .prompt(prompt) - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .control_cond(Arc::clone(&control_image)) - .control_strength(control_strength) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(2) - .batch_count(1) .build() .expect("Failed to build txt2img config"); @@ -317,7 +318,7 @@ mod tests { let result = ctx .lock() .unwrap() - .txt2img(txt2img_config) + .txt2img(&mut txt2img_config) .expect("Failed to generate image"); result.iter().enumerate().for_each(|(batch, img)| { @@ -345,13 +346,13 @@ mod tests { .build() .unwrap(); let ctx = ModelCtx::new(config).expect("Failed to build model context"); - let txt2img_conf = Txt2ImgConfigBuilder::default() + let mut txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("test prompt") .sample_steps(1) .build() .unwrap(); // Hypothetical failure scenario - let result = ctx.txt2img(txt2img_conf); + let result = ctx.txt2img(&mut txt2img_conf); // Expect an error if calling with invalid path // This depends on your real implementation assert!(result.is_err() || result.is_ok()); @@ -364,13 +365,13 @@ mod tests { .build() .unwrap(); let ctx = ModelCtx::new(config).expect("Failed to build model context"); - let txt2img_conf = Txt2ImgConfigBuilder::default() + let mut txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("multi-image prompt") .sample_steps(1) .batch_count(3) .build() .unwrap(); - let result = ctx.txt2img(txt2img_conf); + let result = ctx.txt2img(&mut txt2img_conf); assert!(result.is_ok()); if let Ok(images) = result { assert_eq!(images.len(), 3); diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index 061af26..d83a155 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, sync::Arc}; +use std::path::PathBuf; use crate::utils::SampleMethod; use derive_builder::Builder; @@ -57,7 +57,7 @@ pub struct Txt2ImgConfig { pub batch_count: i32, #[builder(default = "None")] - pub control_cond: Option>, + pub control_cond: Option, /// Strength to apply Control Net (default: 0.9) /// 1.0 corresponds to full destruction of information in init diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 1fcbec7..c9de158 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -36,7 +36,6 @@ hipblas = [] metal = [] vulkan = [] sycl = [] -flashattn = [] [build-dependencies] cc = "1.1.31" diff --git a/sys/build.rs b/sys/build.rs index 4b8d875..fd7b3d7 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -194,11 +194,6 @@ fn main() { config.define("SD_SYCL", "ON"); } - #[cfg(feature = "flashattn")] - { - config.define("SD_FLASH_ATTN", "ON"); - } - // Build stable-diffusion config .profile("Release") From afe65755c8b4beff25e4c8bdc7e96ea4e795b8be Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Sat, 22 Feb 2025 16:44:12 -0500 Subject: [PATCH 23/33] cloning for new --- src/api.rs | 103 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 95 insertions(+), 8 deletions(-) diff --git a/src/api.rs b/src/api.rs index 40276b0..24fcba0 100644 --- a/src/api.rs +++ b/src/api.rs @@ -21,7 +21,7 @@ unsafe impl Send for ModelCtx {} unsafe impl Sync for ModelCtx {} impl ModelCtx { - pub fn new(config: ModelConfig) -> Result { + pub fn new(config: &ModelConfig) -> Result { setup_logging(config.log_callback, config.progress_callback); let ctx = unsafe { @@ -56,7 +56,10 @@ impl ModelCtx { } }; - Ok(Self { ctx, config }) + Ok(Self { + ctx, + config: config.clone(), + }) } pub fn txt2img( @@ -201,7 +204,7 @@ mod tests { .lora_model_dir("./models/loras") .taesd("./models/taesd1.safetensors") .control_net("./models/controlnet/control_canny-fp16.safetensors") - .weight_type(WeightType::SD_TYPE_Q4_1) + .weight_type(WeightType::SD_TYPE_F16) .rng_type(RngFunction::CUDA_RNG) .schedule(Schedule::AYS) .vae_decode_only(true) @@ -209,10 +212,10 @@ mod tests { .build() .expect("Failed to build model config"); - let ctx = ModelCtx::new(model_config).expect("Failed to build model context"); + let ctx = ModelCtx::new(&model_config).expect("Failed to build model context"); let resolution: i32 = 384; - let sample_steps = 3; + let sample_steps = 2; let control_strength = 0.9; let control_image = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") @@ -268,7 +271,7 @@ mod tests { .expect("Failed to build model config"); let ctx = Arc::new(Mutex::new( - ModelCtx::new(model_config).expect("Failed to build model context"), + ModelCtx::new(&model_config).expect("Failed to build model context"), )); let resolution: i32 = 384; @@ -338,6 +341,90 @@ mod tests { } } + #[test] + fn test_txt2img_multithreaded_multimodel_success() { + let model_config = ModelConfigBuilder::default() + .model("./models/mistoonAnime_v30.safetensors") + .lora_model_dir("./models/loras") + .taesd("./models/taesd1.safetensors") + .control_net("./models/controlnet/control_canny-fp16.safetensors") + .weight_type(WeightType::SD_TYPE_F16) + .rng_type(RngFunction::CUDA_RNG) + .schedule(Schedule::AYS) + .vae_decode_only(true) + .flash_attention(false) + .build() + .expect("Failed to build model config"); + + let ctx1 = ModelCtx::new(&model_config).expect("Failed to build model context"); + let ctx2 = ModelCtx::new(&model_config).expect("Failed to build model context"); + + let models = Arc::new(vec![ctx1, ctx2]); + + let resolution: i32 = 384; + let sample_steps = 3; + let control_strength = 0.8; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) + .into_rgb8(); + + let prompts = vec![ + "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", + "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + ]; + + let mut handles = vec![]; + + let mut binding = Txt2ImgConfigBuilder::default(); + let txt2img_config_base = binding + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .control_cond(control_image) + .control_strength(control_strength) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(2) + .batch_count(1); + + for (index, prompt) in prompts.into_iter().enumerate() { + let mut txt2img_config = txt2img_config_base + .prompt(prompt) + .build() + .expect("Failed to build txt2img config"); + + let models = Arc::clone(&models); + let handle = thread::spawn(move || { + let result = models[index] + .txt2img(&mut txt2img_config) + .expect("Failed to generate image"); + + result.iter().enumerate().for_each(|(batch, img)| { + img.save(format!( + "./images/test_mt_#{}_{}x_{}.png", + index, resolution, batch + )) + .unwrap(); + }); + println!("Thread {} finished", index); + }); + + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + } + #[test] fn test_txt2img_failure() { // Build a context with invalid data to force failure @@ -345,7 +432,7 @@ mod tests { .model("./mistoonAnime_v10Illustrious.safetensors") .build() .unwrap(); - let ctx = ModelCtx::new(config).expect("Failed to build model context"); + let ctx = ModelCtx::new(&config).expect("Failed to build model context"); let mut txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("test prompt") .sample_steps(1) @@ -364,7 +451,7 @@ mod tests { .model("./mistoonAnime_v10Illustrious.safetensors") .build() .unwrap(); - let ctx = ModelCtx::new(config).expect("Failed to build model context"); + let ctx = ModelCtx::new(&config).expect("Failed to build model context"); let mut txt2img_conf = Txt2ImgConfigBuilder::default() .prompt("multi-image prompt") .sample_steps(1) From 5cad9a1252f1b5f3028a723c8774cf8a6f07dd69 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Tue, 4 Mar 2025 00:20:04 -0500 Subject: [PATCH 24/33] update to newest sd --- .gitignore | 4 +- Cargo.toml | 9 +- src/api.rs | 28 +- src/txt2img_config.rs | 6 +- sys/Cargo.toml | 1 - sys/build.rs | 23 - sys/src/bindings.rs | 1284 -------------------------------------- sys/stable-diffusion.cpp | 2 +- sys/stb_image_write.c | 6 - sys/wrapper.h | 4 +- 10 files changed, 29 insertions(+), 1338 deletions(-) delete mode 100644 sys/src/bindings.rs delete mode 100644 sys/stb_image_write.c diff --git a/.gitignore b/.gitignore index eaba1a9..39789c0 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ bin/act *.png .idea/ models/ -images/ \ No newline at end of file +images/ + +.DS_Store \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index d8d4572..27feb48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,9 @@ [workspace] members = ["sys"] -resolver = "2" [workspace.package] -version = "0.1.6" -edition = "2021" +version = "0.1.8" +edition = "2024" license = "MIT" repository = "https://github.com/newfla/diffusion-rs" keywords = ["ai", "stable-diffusion", "flux"] @@ -22,11 +21,11 @@ documentation = "https://docs.rs/diffusion-rs" [dependencies] derive_builder = "0.20.2" -diffusion-rs-sys = { path = "sys", version = "0.1.6" } +diffusion-rs-sys = { path = "sys", version = "0.1.8" } image = "0.25.5" libc = "0.2.161" num_cpus = "1.16.0" -thiserror = "2.0.11" +thiserror = "2.0.12" [features] cuda = ["diffusion-rs-sys/cuda"] diff --git a/src/api.rs b/src/api.rs index 24fcba0..1393726 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,10 +1,10 @@ -use std::ffi::{c_void, CString}; +use std::ffi::{CString, c_void}; use std::ptr::null; use std::slice; use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; -use crate::utils::{convert_image, pathbuf_to_c_char, setup_logging, DiffusionError}; +use crate::utils::{DiffusionError, convert_image, pathbuf_to_c_char, setup_logging}; use diffusion_rs_sys::sd_image_t; use image::RgbImage; use libc::free; @@ -107,6 +107,7 @@ impl ModelCtx { txt2img_config.clip_skip, txt2img_config.cfg_scale, txt2img_config.guidance, + txt2img_config.eta, txt2img_config.width, txt2img_config.height, txt2img_config.sample_method, @@ -215,8 +216,8 @@ mod tests { let ctx = ModelCtx::new(&model_config).expect("Failed to build model context"); let resolution: i32 = 384; - let sample_steps = 2; - let control_strength = 0.9; + let sample_steps = 3; + let control_strength = 0.8; let control_image = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") .decode() @@ -232,16 +233,17 @@ mod tests { let mut txt2img_config = Txt2ImgConfigBuilder::default() .prompt(prompt) - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .add_lora_model("pcm_sd15_smallcfg_2step_converted", 1.0) .control_cond(control_image) .control_strength(control_strength) .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) + .sample_method(SampleMethod::TCD) + .eta(1.0) .cfg_scale(1.0) .height(resolution) .width(resolution) .clip_skip(2) - .batch_count(2) + .batch_count(5) .build() .expect("Failed to build txt2img config 1"); @@ -291,7 +293,7 @@ mod tests { let prompts = vec![ "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", - "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy" + "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy", ]; let mut handles = vec![]; @@ -352,7 +354,7 @@ mod tests { .rng_type(RngFunction::CUDA_RNG) .schedule(Schedule::AYS) .vae_decode_only(true) - .flash_attention(false) + .flash_attention(true) .build() .expect("Failed to build model config"); @@ -362,8 +364,8 @@ mod tests { let models = Arc::new(vec![ctx1, ctx2]); let resolution: i32 = 384; - let sample_steps = 3; - let control_strength = 0.8; + let sample_steps = 4; + let control_strength = 0.5; let control_image = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") .decode() @@ -377,7 +379,7 @@ mod tests { let prompts = vec![ "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", - "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + //"masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", ]; let mut handles = vec![]; @@ -409,7 +411,7 @@ mod tests { result.iter().enumerate().for_each(|(batch, img)| { img.save(format!( - "./images/test_mt_#{}_{}x_{}.png", + "./images/test_mt_mm_#{}_{}x_{}.png", index, resolution, batch )) .unwrap(); diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index d83a155..f508093 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -32,6 +32,10 @@ pub struct Txt2ImgConfig { #[builder(default = "3.5")] pub guidance: f32, + /// eta in DDIM, only for DDIM and TCD: (default: 0) + #[builder(default = "0.0")] + pub eta: f32, + /// Image height, in pixel space (default: 512) #[builder(default = "512")] pub height: i32, @@ -82,7 +86,7 @@ pub struct Txt2ImgConfig { /// skip layer guidance (SLG) scale, only for DiT models: (default: 0) /// 0 means disabled, a value of 2.5 is nice for sd3.5 medium - #[builder(default = "0.")] + #[builder(default = "0.0")] pub slg_scale: f32, /// SLG enabling point: (default: 0.01) diff --git a/sys/Cargo.toml b/sys/Cargo.toml index c9de158..6c65fda 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -38,7 +38,6 @@ vulkan = [] sycl = [] [build-dependencies] -cc = "1.1.31" cmake = "0.1.51" bindgen = "0.71.1" fs_extra = "1.3.0" diff --git a/sys/build.rs b/sys/build.rs index fd7b3d7..7ddc6c3 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -17,12 +17,10 @@ fn main() { } println!("cargo:rerun-if-changed=wrapper.h"); - println!("cargo:rerun-if-changed=stb_image_write.c"); // Copy stable-diffusion code into the build script directory let out = PathBuf::from(env::var("OUT_DIR").unwrap()); let diffusion_root = out.join("stable-diffusion.cpp/"); - let stb_write_image_src = diffusion_root.join("thirdparty/stb_image_write.c"); if !diffusion_root.exists() { create_dir_all(&diffusion_root).unwrap(); @@ -33,16 +31,6 @@ fn main() { e ) }); - fs::copy("./stb_image_write.c", &stb_write_image_src).unwrap_or_else(|e| { - panic!( - "Failed to copy stb_image_write to {}: {}", - stb_write_image_src.display(), - e - ) - }); - - remove_default_params_stb(&diffusion_root.join("thirdparty/stb_image_write.h")) - .unwrap_or_else(|e| panic!("Failed to remove default parameters from stb: {}", e)); } // Bindgen @@ -205,11 +193,6 @@ fn main() { let destination = config.build(); - // Build stb write image - let mut builder = cc::Build::new(); - - builder.file(stb_write_image_src).compile("stbwriteimage"); - add_link_search_path(&out.join("lib")).unwrap(); add_link_search_path(&out.join("build")).unwrap(); add_link_search_path(&out).unwrap(); @@ -261,9 +244,3 @@ fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> { Some("stdc++") } } - -fn remove_default_params_stb(file: &Path) -> std::io::Result<()> { - let data = fs::read_to_string(file)?; - let new_data = data.replace("const char* parameters = NULL", "const char* parameters"); - fs::write(file, new_data) -} diff --git a/sys/src/bindings.rs b/sys/src/bindings.rs deleted file mode 100644 index e7dbde9..0000000 --- a/sys/src/bindings.rs +++ /dev/null @@ -1,1284 +0,0 @@ -/* automatically generated by rust-bindgen 0.71.1 */ - -pub const __bool_true_false_are_defined: u32 = 1; -pub const false_: u32 = 0; -pub const true_: u32 = 1; -pub const _VCRT_COMPILER_PREPROCESSOR: u32 = 1; -pub const _SAL_VERSION: u32 = 20; -pub const __SAL_H_VERSION: u32 = 180000000; -pub const _USE_DECLSPECS_FOR_SAL: u32 = 0; -pub const _USE_ATTRIBUTES_FOR_SAL: u32 = 0; -pub const _CRT_PACKING: u32 = 8; -pub const _HAS_EXCEPTIONS: u32 = 1; -pub const _STL_LANG: u32 = 0; -pub const _HAS_CXX17: u32 = 0; -pub const _HAS_CXX20: u32 = 0; -pub const _HAS_CXX23: u32 = 0; -pub const _HAS_NODISCARD: u32 = 0; -pub const _ARM_WINAPI_PARTITION_DESKTOP_SDK_AVAILABLE: u32 = 1; -pub const _CRT_BUILD_DESKTOP_APP: u32 = 1; -pub const _ARGMAX: u32 = 100; -pub const _CRT_INT_MAX: u32 = 2147483647; -pub const _CRT_FUNCTIONS_REQUIRED: u32 = 1; -pub const _CRT_HAS_CXX17: u32 = 0; -pub const _CRT_HAS_C11: u32 = 1; -pub const _CRT_INTERNAL_NONSTDC_NAMES: u32 = 1; -pub const __STDC_SECURE_LIB__: u32 = 200411; -pub const __GOT_SECURE_LIB__: u32 = 200411; -pub const __STDC_WANT_SECURE_LIB__: u32 = 1; -pub const _SECURECRT_FILL_BUFFER_PATTERN: u32 = 254; -pub const _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES: u32 = 0; -pub const _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES_COUNT: u32 = 0; -pub const _CRT_SECURE_CPP_OVERLOAD_SECURE_NAMES: u32 = 1; -pub const _CRT_SECURE_CPP_OVERLOAD_STANDARD_NAMES_MEMORY: u32 = 0; -pub const _CRT_SECURE_CPP_OVERLOAD_SECURE_NAMES_MEMORY: u32 = 0; -pub const WCHAR_MIN: u32 = 0; -pub const WCHAR_MAX: u32 = 65535; -pub const WINT_MIN: u32 = 0; -pub const WINT_MAX: u32 = 65535; -pub const EPERM: u32 = 1; -pub const ENOENT: u32 = 2; -pub const ESRCH: u32 = 3; -pub const EINTR: u32 = 4; -pub const EIO: u32 = 5; -pub const ENXIO: u32 = 6; -pub const E2BIG: u32 = 7; -pub const ENOEXEC: u32 = 8; -pub const EBADF: u32 = 9; -pub const ECHILD: u32 = 10; -pub const EAGAIN: u32 = 11; -pub const ENOMEM: u32 = 12; -pub const EACCES: u32 = 13; -pub const EFAULT: u32 = 14; -pub const EBUSY: u32 = 16; -pub const EEXIST: u32 = 17; -pub const EXDEV: u32 = 18; -pub const ENODEV: u32 = 19; -pub const ENOTDIR: u32 = 20; -pub const EISDIR: u32 = 21; -pub const ENFILE: u32 = 23; -pub const EMFILE: u32 = 24; -pub const ENOTTY: u32 = 25; -pub const EFBIG: u32 = 27; -pub const ENOSPC: u32 = 28; -pub const ESPIPE: u32 = 29; -pub const EROFS: u32 = 30; -pub const EMLINK: u32 = 31; -pub const EPIPE: u32 = 32; -pub const EDOM: u32 = 33; -pub const EDEADLK: u32 = 36; -pub const ENAMETOOLONG: u32 = 38; -pub const ENOLCK: u32 = 39; -pub const ENOSYS: u32 = 40; -pub const ENOTEMPTY: u32 = 41; -pub const EINVAL: u32 = 22; -pub const ERANGE: u32 = 34; -pub const EILSEQ: u32 = 42; -pub const STRUNCATE: u32 = 80; -pub const EDEADLOCK: u32 = 36; -pub const EADDRINUSE: u32 = 100; -pub const EADDRNOTAVAIL: u32 = 101; -pub const EAFNOSUPPORT: u32 = 102; -pub const EALREADY: u32 = 103; -pub const EBADMSG: u32 = 104; -pub const ECANCELED: u32 = 105; -pub const ECONNABORTED: u32 = 106; -pub const ECONNREFUSED: u32 = 107; -pub const ECONNRESET: u32 = 108; -pub const EDESTADDRREQ: u32 = 109; -pub const EHOSTUNREACH: u32 = 110; -pub const EIDRM: u32 = 111; -pub const EINPROGRESS: u32 = 112; -pub const EISCONN: u32 = 113; -pub const ELOOP: u32 = 114; -pub const EMSGSIZE: u32 = 115; -pub const ENETDOWN: u32 = 116; -pub const ENETRESET: u32 = 117; -pub const ENETUNREACH: u32 = 118; -pub const ENOBUFS: u32 = 119; -pub const ENODATA: u32 = 120; -pub const ENOLINK: u32 = 121; -pub const ENOMSG: u32 = 122; -pub const ENOPROTOOPT: u32 = 123; -pub const ENOSR: u32 = 124; -pub const ENOSTR: u32 = 125; -pub const ENOTCONN: u32 = 126; -pub const ENOTRECOVERABLE: u32 = 127; -pub const ENOTSOCK: u32 = 128; -pub const ENOTSUP: u32 = 129; -pub const EOPNOTSUPP: u32 = 130; -pub const EOTHER: u32 = 131; -pub const EOVERFLOW: u32 = 132; -pub const EOWNERDEAD: u32 = 133; -pub const EPROTO: u32 = 134; -pub const EPROTONOSUPPORT: u32 = 135; -pub const EPROTOTYPE: u32 = 136; -pub const ETIME: u32 = 137; -pub const ETIMEDOUT: u32 = 138; -pub const ETXTBSY: u32 = 139; -pub const EWOULDBLOCK: u32 = 140; -pub const _NLSCMPERROR: u32 = 2147483647; -pub type va_list = *mut ::std::os::raw::c_char; -unsafe extern "C" { - pub fn __va_start(arg1: *mut *mut ::std::os::raw::c_char, ...); -} -pub type __vcrt_bool = bool; -pub type wchar_t = ::std::os::raw::c_ushort; -unsafe extern "C" { - pub fn __security_init_cookie(); -} -unsafe extern "C" { - pub fn __security_check_cookie(_StackCookie: usize); -} -unsafe extern "C" { - pub fn __report_gsfailure(_StackCookie: usize) -> !; -} -unsafe extern "C" { - pub static mut __security_cookie: usize; -} -pub type __crt_bool = bool; -unsafe extern "C" { - pub fn _invalid_parameter_noinfo(); -} -unsafe extern "C" { - pub fn _invalid_parameter_noinfo_noreturn() -> !; -} -unsafe extern "C" { - pub fn _invoke_watson( - _Expression: *const wchar_t, - _FunctionName: *const wchar_t, - _FileName: *const wchar_t, - _LineNo: ::std::os::raw::c_uint, - _Reserved: usize, - ) -> !; -} -pub type errno_t = ::std::os::raw::c_int; -pub type wint_t = ::std::os::raw::c_ushort; -pub type wctype_t = ::std::os::raw::c_ushort; -pub type __time32_t = ::std::os::raw::c_long; -pub type __time64_t = ::std::os::raw::c_longlong; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct __crt_locale_data_public { - pub _locale_pctype: *const ::std::os::raw::c_ushort, - pub _locale_mb_cur_max: ::std::os::raw::c_int, - pub _locale_lc_codepage: ::std::os::raw::c_uint, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of __crt_locale_data_public"] - [::std::mem::size_of::<__crt_locale_data_public>() - 16usize]; - ["Alignment of __crt_locale_data_public"] - [::std::mem::align_of::<__crt_locale_data_public>() - 8usize]; - ["Offset of field: __crt_locale_data_public::_locale_pctype"] - [::std::mem::offset_of!(__crt_locale_data_public, _locale_pctype) - 0usize]; - ["Offset of field: __crt_locale_data_public::_locale_mb_cur_max"] - [::std::mem::offset_of!(__crt_locale_data_public, _locale_mb_cur_max) - 8usize]; - ["Offset of field: __crt_locale_data_public::_locale_lc_codepage"] - [::std::mem::offset_of!(__crt_locale_data_public, _locale_lc_codepage) - 12usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct __crt_locale_pointers { - pub locinfo: *mut __crt_locale_data, - pub mbcinfo: *mut __crt_multibyte_data, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of __crt_locale_pointers"][::std::mem::size_of::<__crt_locale_pointers>() - 16usize]; - ["Alignment of __crt_locale_pointers"] - [::std::mem::align_of::<__crt_locale_pointers>() - 8usize]; - ["Offset of field: __crt_locale_pointers::locinfo"] - [::std::mem::offset_of!(__crt_locale_pointers, locinfo) - 0usize]; - ["Offset of field: __crt_locale_pointers::mbcinfo"] - [::std::mem::offset_of!(__crt_locale_pointers, mbcinfo) - 8usize]; -}; -pub type _locale_t = *mut __crt_locale_pointers; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct _Mbstatet { - pub _Wchar: ::std::os::raw::c_ulong, - pub _Byte: ::std::os::raw::c_ushort, - pub _State: ::std::os::raw::c_ushort, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of _Mbstatet"][::std::mem::size_of::<_Mbstatet>() - 8usize]; - ["Alignment of _Mbstatet"][::std::mem::align_of::<_Mbstatet>() - 4usize]; - ["Offset of field: _Mbstatet::_Wchar"][::std::mem::offset_of!(_Mbstatet, _Wchar) - 0usize]; - ["Offset of field: _Mbstatet::_Byte"][::std::mem::offset_of!(_Mbstatet, _Byte) - 4usize]; - ["Offset of field: _Mbstatet::_State"][::std::mem::offset_of!(_Mbstatet, _State) - 6usize]; -}; -pub type mbstate_t = _Mbstatet; -pub type time_t = __time64_t; -pub type rsize_t = usize; -unsafe extern "C" { - pub fn _errno() -> *mut ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _set_errno(_Value: ::std::os::raw::c_int) -> errno_t; -} -unsafe extern "C" { - pub fn _get_errno(_Value: *mut ::std::os::raw::c_int) -> errno_t; -} -unsafe extern "C" { - pub fn __threadid() -> ::std::os::raw::c_ulong; -} -unsafe extern "C" { - pub fn __threadhandle() -> usize; -} -pub type int_least8_t = ::std::os::raw::c_schar; -pub type int_least16_t = ::std::os::raw::c_short; -pub type int_least32_t = ::std::os::raw::c_int; -pub type int_least64_t = ::std::os::raw::c_longlong; -pub type uint_least8_t = ::std::os::raw::c_uchar; -pub type uint_least16_t = ::std::os::raw::c_ushort; -pub type uint_least32_t = ::std::os::raw::c_uint; -pub type uint_least64_t = ::std::os::raw::c_ulonglong; -pub type int_fast8_t = ::std::os::raw::c_schar; -pub type int_fast16_t = ::std::os::raw::c_int; -pub type int_fast32_t = ::std::os::raw::c_int; -pub type int_fast64_t = ::std::os::raw::c_longlong; -pub type uint_fast8_t = ::std::os::raw::c_uchar; -pub type uint_fast16_t = ::std::os::raw::c_uint; -pub type uint_fast32_t = ::std::os::raw::c_uint; -pub type uint_fast64_t = ::std::os::raw::c_ulonglong; -pub type intmax_t = ::std::os::raw::c_longlong; -pub type uintmax_t = ::std::os::raw::c_ulonglong; -unsafe extern "C" { - pub fn __doserrno() -> *mut ::std::os::raw::c_ulong; -} -unsafe extern "C" { - pub fn _set_doserrno(_Value: ::std::os::raw::c_ulong) -> errno_t; -} -unsafe extern "C" { - pub fn _get_doserrno(_Value: *mut ::std::os::raw::c_ulong) -> errno_t; -} -unsafe extern "C" { - pub fn memchr( - _Buf: *const ::std::os::raw::c_void, - _Val: ::std::os::raw::c_int, - _MaxCount: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_void; -} -unsafe extern "C" { - pub fn memcmp( - _Buf1: *const ::std::os::raw::c_void, - _Buf2: *const ::std::os::raw::c_void, - _Size: ::std::os::raw::c_ulonglong, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn memcpy( - _Dst: *mut ::std::os::raw::c_void, - _Src: *const ::std::os::raw::c_void, - _Size: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_void; -} -unsafe extern "C" { - pub fn memmove( - _Dst: *mut ::std::os::raw::c_void, - _Src: *const ::std::os::raw::c_void, - _Size: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_void; -} -unsafe extern "C" { - pub fn memset( - _Dst: *mut ::std::os::raw::c_void, - _Val: ::std::os::raw::c_int, - _Size: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_void; -} -unsafe extern "C" { - pub fn strchr( - _Str: *const ::std::os::raw::c_char, - _Val: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strrchr( - _Str: *const ::std::os::raw::c_char, - _Ch: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strstr( - _Str: *const ::std::os::raw::c_char, - _SubStr: *const ::std::os::raw::c_char, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn wcschr( - _Str: *const ::std::os::raw::c_ushort, - _Ch: ::std::os::raw::c_ushort, - ) -> *mut ::std::os::raw::c_ushort; -} -unsafe extern "C" { - pub fn wcsrchr(_Str: *const wchar_t, _Ch: wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsstr(_Str: *const wchar_t, _SubStr: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _memicmp( - _Buf1: *const ::std::os::raw::c_void, - _Buf2: *const ::std::os::raw::c_void, - _Size: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _memicmp_l( - _Buf1: *const ::std::os::raw::c_void, - _Buf2: *const ::std::os::raw::c_void, - _Size: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn memccpy( - _Dst: *mut ::std::os::raw::c_void, - _Src: *const ::std::os::raw::c_void, - _Val: ::std::os::raw::c_int, - _Size: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_void; -} -unsafe extern "C" { - pub fn memicmp( - _Buf1: *const ::std::os::raw::c_void, - _Buf2: *const ::std::os::raw::c_void, - _Size: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn wcscat_s( - _Destination: *mut wchar_t, - _SizeInWords: rsize_t, - _Source: *const wchar_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn wcscpy_s( - _Destination: *mut wchar_t, - _SizeInWords: rsize_t, - _Source: *const wchar_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn wcsncat_s( - _Destination: *mut wchar_t, - _SizeInWords: rsize_t, - _Source: *const wchar_t, - _MaxCount: rsize_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn wcsncpy_s( - _Destination: *mut wchar_t, - _SizeInWords: rsize_t, - _Source: *const wchar_t, - _MaxCount: rsize_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn wcstok_s( - _String: *mut wchar_t, - _Delimiter: *const wchar_t, - _Context: *mut *mut wchar_t, - ) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcsdup(_String: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcscat(_Destination: *mut wchar_t, _Source: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcscmp( - _String1: *const ::std::os::raw::c_ushort, - _String2: *const ::std::os::raw::c_ushort, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn wcscpy(_Destination: *mut wchar_t, _Source: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcscspn(_String: *const wchar_t, _Control: *const wchar_t) -> usize; -} -unsafe extern "C" { - pub fn wcslen(_String: *const ::std::os::raw::c_ushort) -> ::std::os::raw::c_ulonglong; -} -unsafe extern "C" { - pub fn wcsnlen(_Source: *const wchar_t, _MaxCount: usize) -> usize; -} -unsafe extern "C" { - pub fn wcsncat( - _Destination: *mut wchar_t, - _Source: *const wchar_t, - _Count: usize, - ) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsncmp( - _String1: *const ::std::os::raw::c_ushort, - _String2: *const ::std::os::raw::c_ushort, - _MaxCount: ::std::os::raw::c_ulonglong, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn wcsncpy( - _Destination: *mut wchar_t, - _Source: *const wchar_t, - _Count: usize, - ) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcspbrk(_String: *const wchar_t, _Control: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsspn(_String: *const wchar_t, _Control: *const wchar_t) -> usize; -} -unsafe extern "C" { - pub fn wcstok( - _String: *mut wchar_t, - _Delimiter: *const wchar_t, - _Context: *mut *mut wchar_t, - ) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcserror(_ErrorNumber: ::std::os::raw::c_int) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcserror_s( - _Buffer: *mut wchar_t, - _SizeInWords: usize, - _ErrorNumber: ::std::os::raw::c_int, - ) -> errno_t; -} -unsafe extern "C" { - pub fn __wcserror(_String: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn __wcserror_s( - _Buffer: *mut wchar_t, - _SizeInWords: usize, - _ErrorMessage: *const wchar_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn _wcsicmp(_String1: *const wchar_t, _String2: *const wchar_t) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsicmp_l( - _String1: *const wchar_t, - _String2: *const wchar_t, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsnicmp( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsnicmp_l( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsnset_s( - _Destination: *mut wchar_t, - _SizeInWords: usize, - _Value: wchar_t, - _MaxCount: usize, - ) -> errno_t; -} -unsafe extern "C" { - pub fn _wcsnset(_String: *mut wchar_t, _Value: wchar_t, _MaxCount: usize) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcsrev(_String: *mut wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcsset_s(_Destination: *mut wchar_t, _SizeInWords: usize, _Value: wchar_t) -> errno_t; -} -unsafe extern "C" { - pub fn _wcsset(_String: *mut wchar_t, _Value: wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcslwr_s(_String: *mut wchar_t, _SizeInWords: usize) -> errno_t; -} -unsafe extern "C" { - pub fn _wcslwr(_String: *mut wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcslwr_s_l(_String: *mut wchar_t, _SizeInWords: usize, _Locale: _locale_t) -> errno_t; -} -unsafe extern "C" { - pub fn _wcslwr_l(_String: *mut wchar_t, _Locale: _locale_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcsupr_s(_String: *mut wchar_t, _Size: usize) -> errno_t; -} -unsafe extern "C" { - pub fn _wcsupr(_String: *mut wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn _wcsupr_s_l(_String: *mut wchar_t, _Size: usize, _Locale: _locale_t) -> errno_t; -} -unsafe extern "C" { - pub fn _wcsupr_l(_String: *mut wchar_t, _Locale: _locale_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsxfrm(_Destination: *mut wchar_t, _Source: *const wchar_t, _MaxCount: usize) -> usize; -} -unsafe extern "C" { - pub fn _wcsxfrm_l( - _Destination: *mut wchar_t, - _Source: *const wchar_t, - _MaxCount: usize, - _Locale: _locale_t, - ) -> usize; -} -unsafe extern "C" { - pub fn wcscoll(_String1: *const wchar_t, _String2: *const wchar_t) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcscoll_l( - _String1: *const wchar_t, - _String2: *const wchar_t, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsicoll(_String1: *const wchar_t, _String2: *const wchar_t) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsicoll_l( - _String1: *const wchar_t, - _String2: *const wchar_t, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsncoll( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsncoll_l( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsnicoll( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _wcsnicoll_l( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn wcsdup(_String: *const wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsicmp(_String1: *const wchar_t, _String2: *const wchar_t) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn wcsnicmp( - _String1: *const wchar_t, - _String2: *const wchar_t, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn wcsnset(_String: *mut wchar_t, _Value: wchar_t, _MaxCount: usize) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsrev(_String: *mut wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsset(_String: *mut wchar_t, _Value: wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcslwr(_String: *mut wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsupr(_String: *mut wchar_t) -> *mut wchar_t; -} -unsafe extern "C" { - pub fn wcsicoll(_String1: *const wchar_t, _String2: *const wchar_t) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn strcpy_s( - _Destination: *mut ::std::os::raw::c_char, - _SizeInBytes: rsize_t, - _Source: *const ::std::os::raw::c_char, - ) -> errno_t; -} -unsafe extern "C" { - pub fn strcat_s( - _Destination: *mut ::std::os::raw::c_char, - _SizeInBytes: rsize_t, - _Source: *const ::std::os::raw::c_char, - ) -> errno_t; -} -unsafe extern "C" { - pub fn strerror_s( - _Buffer: *mut ::std::os::raw::c_char, - _SizeInBytes: usize, - _ErrorNumber: ::std::os::raw::c_int, - ) -> errno_t; -} -unsafe extern "C" { - pub fn strncat_s( - _Destination: *mut ::std::os::raw::c_char, - _SizeInBytes: rsize_t, - _Source: *const ::std::os::raw::c_char, - _MaxCount: rsize_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn strncpy_s( - _Destination: *mut ::std::os::raw::c_char, - _SizeInBytes: rsize_t, - _Source: *const ::std::os::raw::c_char, - _MaxCount: rsize_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn strtok_s( - _String: *mut ::std::os::raw::c_char, - _Delimiter: *const ::std::os::raw::c_char, - _Context: *mut *mut ::std::os::raw::c_char, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _memccpy( - _Dst: *mut ::std::os::raw::c_void, - _Src: *const ::std::os::raw::c_void, - _Val: ::std::os::raw::c_int, - _MaxCount: usize, - ) -> *mut ::std::os::raw::c_void; -} -unsafe extern "C" { - pub fn strcat( - _Destination: *mut ::std::os::raw::c_char, - _Source: *const ::std::os::raw::c_char, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strcmp( - _Str1: *const ::std::os::raw::c_char, - _Str2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strcmpi( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn strcoll( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strcoll_l( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn strcpy( - _Destination: *mut ::std::os::raw::c_char, - _Source: *const ::std::os::raw::c_char, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strcspn( - _Str: *const ::std::os::raw::c_char, - _Control: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_ulonglong; -} -unsafe extern "C" { - pub fn _strdup(_Source: *const ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strerror(_ErrorMessage: *const ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strerror_s( - _Buffer: *mut ::std::os::raw::c_char, - _SizeInBytes: usize, - _ErrorMessage: *const ::std::os::raw::c_char, - ) -> errno_t; -} -unsafe extern "C" { - pub fn strerror(_ErrorMessage: ::std::os::raw::c_int) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _stricmp( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _stricoll( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _stricoll_l( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _stricmp_l( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn strlen(_Str: *const ::std::os::raw::c_char) -> ::std::os::raw::c_ulonglong; -} -unsafe extern "C" { - pub fn _strlwr_s(_String: *mut ::std::os::raw::c_char, _Size: usize) -> errno_t; -} -unsafe extern "C" { - pub fn _strlwr(_String: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strlwr_s_l( - _String: *mut ::std::os::raw::c_char, - _Size: usize, - _Locale: _locale_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn _strlwr_l( - _String: *mut ::std::os::raw::c_char, - _Locale: _locale_t, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strncat( - _Destination: *mut ::std::os::raw::c_char, - _Source: *const ::std::os::raw::c_char, - _Count: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strncmp( - _Str1: *const ::std::os::raw::c_char, - _Str2: *const ::std::os::raw::c_char, - _MaxCount: ::std::os::raw::c_ulonglong, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strnicmp( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strnicmp_l( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strnicoll( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strnicoll_l( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strncoll( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn _strncoll_l( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - _Locale: _locale_t, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn __strncnt(_String: *const ::std::os::raw::c_char, _Count: usize) -> usize; -} -unsafe extern "C" { - pub fn strncpy( - _Destination: *mut ::std::os::raw::c_char, - _Source: *const ::std::os::raw::c_char, - _Count: ::std::os::raw::c_ulonglong, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strnlen(_String: *const ::std::os::raw::c_char, _MaxCount: usize) -> usize; -} -unsafe extern "C" { - pub fn _strnset_s( - _String: *mut ::std::os::raw::c_char, - _SizeInBytes: usize, - _Value: ::std::os::raw::c_int, - _MaxCount: usize, - ) -> errno_t; -} -unsafe extern "C" { - pub fn _strnset( - _Destination: *mut ::std::os::raw::c_char, - _Value: ::std::os::raw::c_int, - _Count: usize, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strpbrk( - _Str: *const ::std::os::raw::c_char, - _Control: *const ::std::os::raw::c_char, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strrev(_Str: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strset_s( - _Destination: *mut ::std::os::raw::c_char, - _DestinationSize: usize, - _Value: ::std::os::raw::c_int, - ) -> errno_t; -} -unsafe extern "C" { - pub fn _strset( - _Destination: *mut ::std::os::raw::c_char, - _Value: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strspn( - _Str: *const ::std::os::raw::c_char, - _Control: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_ulonglong; -} -unsafe extern "C" { - pub fn strtok( - _String: *mut ::std::os::raw::c_char, - _Delimiter: *const ::std::os::raw::c_char, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strupr_s(_String: *mut ::std::os::raw::c_char, _Size: usize) -> errno_t; -} -unsafe extern "C" { - pub fn _strupr(_String: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn _strupr_s_l( - _String: *mut ::std::os::raw::c_char, - _Size: usize, - _Locale: _locale_t, - ) -> errno_t; -} -unsafe extern "C" { - pub fn _strupr_l( - _String: *mut ::std::os::raw::c_char, - _Locale: _locale_t, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strxfrm( - _Destination: *mut ::std::os::raw::c_char, - _Source: *const ::std::os::raw::c_char, - _MaxCount: ::std::os::raw::c_ulonglong, - ) -> ::std::os::raw::c_ulonglong; -} -unsafe extern "C" { - pub fn _strxfrm_l( - _Destination: *mut ::std::os::raw::c_char, - _Source: *const ::std::os::raw::c_char, - _MaxCount: usize, - _Locale: _locale_t, - ) -> usize; -} -unsafe extern "C" { - pub fn strdup(_String: *const ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strcmpi( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn stricmp( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn strlwr(_String: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strnicmp( - _String1: *const ::std::os::raw::c_char, - _String2: *const ::std::os::raw::c_char, - _MaxCount: usize, - ) -> ::std::os::raw::c_int; -} -unsafe extern "C" { - pub fn strnset( - _String: *mut ::std::os::raw::c_char, - _Value: ::std::os::raw::c_int, - _MaxCount: usize, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strrev(_String: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strset( - _String: *mut ::std::os::raw::c_char, - _Value: ::std::os::raw::c_int, - ) -> *mut ::std::os::raw::c_char; -} -unsafe extern "C" { - pub fn strupr(_String: *mut ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; -} -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum rng_type_t { - STD_DEFAULT_RNG = 0, - CUDA_RNG = 1, -} -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum sample_method_t { - EULER_A = 0, - EULER = 1, - HEUN = 2, - DPM2 = 3, - DPMPP2S_A = 4, - DPMPP2M = 5, - DPMPP2Mv2 = 6, - IPNDM = 7, - IPNDM_V = 8, - LCM = 9, - N_SAMPLE_METHODS = 10, -} -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum schedule_t { - DEFAULT = 0, - DISCRETE = 1, - KARRAS = 2, - EXPONENTIAL = 3, - AYS = 4, - GITS = 5, - N_SCHEDULES = 6, -} -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum sd_type_t { - SD_TYPE_F32 = 0, - SD_TYPE_F16 = 1, - SD_TYPE_Q4_0 = 2, - SD_TYPE_Q4_1 = 3, - SD_TYPE_Q5_0 = 6, - SD_TYPE_Q5_1 = 7, - SD_TYPE_Q8_0 = 8, - SD_TYPE_Q8_1 = 9, - SD_TYPE_Q2_K = 10, - SD_TYPE_Q3_K = 11, - SD_TYPE_Q4_K = 12, - SD_TYPE_Q5_K = 13, - SD_TYPE_Q6_K = 14, - SD_TYPE_Q8_K = 15, - SD_TYPE_IQ2_XXS = 16, - SD_TYPE_IQ2_XS = 17, - SD_TYPE_IQ3_XXS = 18, - SD_TYPE_IQ1_S = 19, - SD_TYPE_IQ4_NL = 20, - SD_TYPE_IQ3_S = 21, - SD_TYPE_IQ2_S = 22, - SD_TYPE_IQ4_XS = 23, - SD_TYPE_I8 = 24, - SD_TYPE_I16 = 25, - SD_TYPE_I32 = 26, - SD_TYPE_I64 = 27, - SD_TYPE_F64 = 28, - SD_TYPE_IQ1_M = 29, - SD_TYPE_BF16 = 30, - SD_TYPE_Q4_0_4_4 = 31, - SD_TYPE_Q4_0_4_8 = 32, - SD_TYPE_Q4_0_8_8 = 33, - SD_TYPE_TQ1_0 = 34, - SD_TYPE_TQ2_0 = 35, - SD_TYPE_COUNT = 36, -} -unsafe extern "C" { - pub fn sd_type_name(type_: sd_type_t) -> *const ::std::os::raw::c_char; -} -#[repr(i32)] -#[non_exhaustive] -#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] -pub enum sd_log_level_t { - SD_LOG_DEBUG = 0, - SD_LOG_INFO = 1, - SD_LOG_WARN = 2, - SD_LOG_ERROR = 3, -} -pub type sd_log_cb_t = ::std::option::Option< - unsafe extern "C" fn( - level: sd_log_level_t, - text: *const ::std::os::raw::c_char, - data: *mut ::std::os::raw::c_void, - ), ->; -pub type sd_progress_cb_t = ::std::option::Option< - unsafe extern "C" fn( - step: ::std::os::raw::c_int, - steps: ::std::os::raw::c_int, - time: f32, - data: *mut ::std::os::raw::c_void, - ), ->; -unsafe extern "C" { - pub fn sd_set_log_callback(sd_log_cb: sd_log_cb_t, data: *mut ::std::os::raw::c_void); -} -unsafe extern "C" { - pub fn sd_set_progress_callback(cb: sd_progress_cb_t, data: *mut ::std::os::raw::c_void); -} -unsafe extern "C" { - pub fn get_num_physical_cores() -> i32; -} -unsafe extern "C" { - pub fn sd_get_system_info() -> *const ::std::os::raw::c_char; -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct sd_image_t { - pub width: u32, - pub height: u32, - pub channel: u32, - pub data: *mut u8, -} -#[allow(clippy::unnecessary_operation, clippy::identity_op)] -const _: () = { - ["Size of sd_image_t"][::std::mem::size_of::() - 24usize]; - ["Alignment of sd_image_t"][::std::mem::align_of::() - 8usize]; - ["Offset of field: sd_image_t::width"][::std::mem::offset_of!(sd_image_t, width) - 0usize]; - ["Offset of field: sd_image_t::height"][::std::mem::offset_of!(sd_image_t, height) - 4usize]; - ["Offset of field: sd_image_t::channel"][::std::mem::offset_of!(sd_image_t, channel) - 8usize]; - ["Offset of field: sd_image_t::data"][::std::mem::offset_of!(sd_image_t, data) - 16usize]; -}; -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct sd_ctx_t { - _unused: [u8; 0], -} -unsafe extern "C" { - pub fn new_sd_ctx( - model_path: *const ::std::os::raw::c_char, - clip_l_path: *const ::std::os::raw::c_char, - clip_g_path: *const ::std::os::raw::c_char, - t5xxl_path: *const ::std::os::raw::c_char, - diffusion_model_path: *const ::std::os::raw::c_char, - vae_path: *const ::std::os::raw::c_char, - taesd_path: *const ::std::os::raw::c_char, - control_net_path_c_str: *const ::std::os::raw::c_char, - lora_model_dir: *const ::std::os::raw::c_char, - embed_dir_c_str: *const ::std::os::raw::c_char, - stacked_id_embed_dir_c_str: *const ::std::os::raw::c_char, - vae_decode_only: bool, - vae_tiling: bool, - free_params_immediately: bool, - n_threads: ::std::os::raw::c_int, - wtype: sd_type_t, - rng_type: rng_type_t, - s: schedule_t, - keep_clip_on_cpu: bool, - keep_control_net_cpu: bool, - keep_vae_on_cpu: bool, - diffusion_flash_attn: bool, - ) -> *mut sd_ctx_t; -} -unsafe extern "C" { - pub fn free_sd_ctx(sd_ctx: *mut sd_ctx_t); -} -unsafe extern "C" { - pub fn txt2img( - sd_ctx: *mut sd_ctx_t, - prompt: *const ::std::os::raw::c_char, - negative_prompt: *const ::std::os::raw::c_char, - clip_skip: ::std::os::raw::c_int, - cfg_scale: f32, - guidance: f32, - width: ::std::os::raw::c_int, - height: ::std::os::raw::c_int, - sample_method: sample_method_t, - sample_steps: ::std::os::raw::c_int, - seed: i64, - batch_count: ::std::os::raw::c_int, - control_cond: *const sd_image_t, - control_strength: f32, - style_strength: f32, - normalize_input: bool, - input_id_images_path: *const ::std::os::raw::c_char, - skip_layers: *mut ::std::os::raw::c_int, - skip_layers_count: usize, - slg_scale: f32, - skip_layer_start: f32, - skip_layer_end: f32, - ) -> *mut sd_image_t; -} -unsafe extern "C" { - pub fn img2img( - sd_ctx: *mut sd_ctx_t, - init_image: sd_image_t, - mask_image: sd_image_t, - prompt: *const ::std::os::raw::c_char, - negative_prompt: *const ::std::os::raw::c_char, - clip_skip: ::std::os::raw::c_int, - cfg_scale: f32, - guidance: f32, - width: ::std::os::raw::c_int, - height: ::std::os::raw::c_int, - sample_method: sample_method_t, - sample_steps: ::std::os::raw::c_int, - strength: f32, - seed: i64, - batch_count: ::std::os::raw::c_int, - control_cond: *const sd_image_t, - control_strength: f32, - style_strength: f32, - normalize_input: bool, - input_id_images_path: *const ::std::os::raw::c_char, - skip_layers: *mut ::std::os::raw::c_int, - skip_layers_count: usize, - slg_scale: f32, - skip_layer_start: f32, - skip_layer_end: f32, - ) -> *mut sd_image_t; -} -unsafe extern "C" { - pub fn img2vid( - sd_ctx: *mut sd_ctx_t, - init_image: sd_image_t, - width: ::std::os::raw::c_int, - height: ::std::os::raw::c_int, - video_frames: ::std::os::raw::c_int, - motion_bucket_id: ::std::os::raw::c_int, - fps: ::std::os::raw::c_int, - augmentation_level: f32, - min_cfg: f32, - cfg_scale: f32, - sample_method: sample_method_t, - sample_steps: ::std::os::raw::c_int, - strength: f32, - seed: i64, - ) -> *mut sd_image_t; -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct upscaler_ctx_t { - _unused: [u8; 0], -} -unsafe extern "C" { - pub fn new_upscaler_ctx( - esrgan_path: *const ::std::os::raw::c_char, - n_threads: ::std::os::raw::c_int, - ) -> *mut upscaler_ctx_t; -} -unsafe extern "C" { - pub fn free_upscaler_ctx(upscaler_ctx: *mut upscaler_ctx_t); -} -unsafe extern "C" { - pub fn upscale( - upscaler_ctx: *mut upscaler_ctx_t, - input_image: sd_image_t, - upscale_factor: u32, - ) -> sd_image_t; -} -unsafe extern "C" { - pub fn convert( - input_path: *const ::std::os::raw::c_char, - vae_path: *const ::std::os::raw::c_char, - output_path: *const ::std::os::raw::c_char, - output_type: sd_type_t, - ) -> bool; -} -unsafe extern "C" { - pub fn preprocess_canny( - img: *mut u8, - width: ::std::os::raw::c_int, - height: ::std::os::raw::c_int, - high_threshold: f32, - low_threshold: f32, - weak: f32, - strong: f32, - inverse: bool, - ) -> *mut u8; -} -unsafe extern "C" { - pub fn stbi_write_png_custom( - filename: *const ::std::os::raw::c_char, - w: ::std::os::raw::c_int, - h: ::std::os::raw::c_int, - comp: ::std::os::raw::c_int, - data: *const ::std::os::raw::c_void, - stride_in_bytes: ::std::os::raw::c_int, - ) -> ::std::os::raw::c_int; -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct __crt_locale_data { - pub _address: u8, -} -#[repr(C)] -#[derive(Debug, Copy, Clone)] -pub struct __crt_multibyte_data { - pub _address: u8, -} diff --git a/sys/stable-diffusion.cpp b/sys/stable-diffusion.cpp index d9b5942..30b3ac8 160000 --- a/sys/stable-diffusion.cpp +++ b/sys/stable-diffusion.cpp @@ -1 +1 @@ -Subproject commit d9b5942d988ee36c2f2d8a2d79820e90110947c3 +Subproject commit 30b3ac8e6279c128a7a3f8c2627d31b96c2a1185 diff --git a/sys/stb_image_write.c b/sys/stb_image_write.c deleted file mode 100644 index ae0acc9..0000000 --- a/sys/stb_image_write.c +++ /dev/null @@ -1,6 +0,0 @@ -#define STB_IMAGE_WRITE_IMPLEMENTATION -#include "stb_image_write.h" - -int stbi_write_png_custom(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes) { - return stbi_write_png(filename, w, h, comp, data, stride_in_bytes, NULL); -} diff --git a/sys/wrapper.h b/sys/wrapper.h index 74eb554..f112757 100644 --- a/sys/wrapper.h +++ b/sys/wrapper.h @@ -1,3 +1 @@ -#include - -int stbi_write_png_custom(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); +#include \ No newline at end of file From 8186fbf1cfdfa6818dd422183ef2b13b4de0165a Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Wed, 19 Mar 2025 10:44:27 -0400 Subject: [PATCH 25/33] add binding and clean readme --- README.md | 16 +- src/utils.rs | 40 +- sys/src/bindings.rs | 1353 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1384 insertions(+), 25 deletions(-) create mode 100644 sys/src/bindings.rs diff --git a/README.md b/README.md index e066cd7..4bcc10a 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,12 @@ # diffusion-rs + [![Latest version](https://img.shields.io/crates/v/diffusion-rs.svg)](https://crates.io/crates/diffusion-rs) [![Documentation](https://docs.rs/diffusion-rs/badge.svg)](https://docs.rs/diffusion-rs) Rust bindings to ## Features Matrix + | | Windows | Mac | Linux | | --- | :---: | :---: | :---: | |vulkan| ✅️ | ⛓️‍💥 | ✅️ | @@ -17,7 +19,8 @@ Rust bindings to ⛓️‍💥 : Issues when linking libraries -## Usage +## Usage + ``` rust no_run use diffusion_rs::{api::txt2img, preset::{Preset,PresetBuilder}}; let config = PresetBuilder::default() @@ -31,16 +34,17 @@ txt2img(config).unwrap(); ## Troubleshooting * Something other than Windows/Linux isn't working! - * I don't have a way to test these platforms, so I can't really help you. + * I don't have a way to test these platforms, so I can't really help you. * I get a panic during binding generation build! - * You can attempt to fix it yourself, or you can set the `DIFFUSION_SKIP_BINDINGS` environment variable. + * You can attempt to fix it yourself, or you can set the `DIFFUSION_SKIP_BINDINGS` environment variable. This skips attempting to build the bindings whatsoever and copies the existing ones. They may be out of date, but it's better than nothing. - * `DIFFUSION_SKIP_BINDINGS=1 cargo build` - * If you can fix the issue, please open a PR! + * `DIFFUSION_SKIP_BINDINGS=1 cargo build` + * If you can fix the issue, please open a PR! ## Roadmap + 1. ~~Ensure that the underline cpp library compiles on supported platforms~~ 2. ~~Build an easy to use library with model presets~~ 3. ~~Automatic library publishing on crates.io by gh actions~~ -4. _Maybe_ prebuilt CLI app binaries \ No newline at end of file +4. _Maybe_ prebuilt CLI app binaries diff --git a/src/utils.rs b/src/utils.rs index 7ed9263..31124a1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -4,10 +4,10 @@ use diffusion_rs_sys::sd_set_progress_callback; use image::ImageBuffer; use image::Rgb; use image::RgbImage; -use std::ffi::c_char; -use std::ffi::c_void; use std::ffi::CStr; use std::ffi::CString; +use std::ffi::c_char; +use std::ffi::c_void; use std::path::PathBuf; use std::slice; use thiserror::Error; @@ -80,12 +80,30 @@ extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _ unsafe { // Convert C string to Rust &str and print it. if !text.is_null() { - let msg = CStr::from_ptr(text).to_str().unwrap_or("Invalid UTF-8"); + let msg = CStr::from_ptr(text) + .to_str() + .unwrap_or("LOG ERROR: Invalid UTF-8"); print!("({:?}): {}", level, msg); } } } +pub fn setup_logging( + log_callback: Option, + progress_callback: Option, +) { + unsafe { + match log_callback { + Some(callback) => sd_set_log_callback(Some(callback), std::ptr::null_mut()), + None => sd_set_log_callback(Some(default_log_callback), std::ptr::null_mut()), + }; + match progress_callback { + Some(callback) => sd_set_progress_callback(Some(callback), std::ptr::null_mut()), + None => (), + }; + } +} + // use std::sync::LazyLock; // static BAR: LazyLock> = LazyLock::new(|| { @@ -112,19 +130,3 @@ extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _ // bar.set_message(format!("Elapsed: {:.2} s", time)); // } // } - -pub fn setup_logging( - log_callback: Option, - progress_callback: Option, -) { - unsafe { - match log_callback { - Some(callback) => sd_set_log_callback(Some(callback), std::ptr::null_mut()), - None => sd_set_log_callback(Some(default_log_callback), std::ptr::null_mut()), - }; - match progress_callback { - Some(callback) => sd_set_progress_callback(Some(callback), std::ptr::null_mut()), - None => (), - }; - } -} diff --git a/sys/src/bindings.rs b/sys/src/bindings.rs new file mode 100644 index 0000000..7f8b28e --- /dev/null +++ b/sys/src/bindings.rs @@ -0,0 +1,1353 @@ +/* automatically generated by rust-bindgen 0.71.1 */ + +pub const __bool_true_false_are_defined: u32 = 1; +pub const true_: u32 = 1; +pub const false_: u32 = 0; +pub const __WORDSIZE: u32 = 64; +pub const __has_safe_buffers: u32 = 1; +pub const __DARWIN_ONLY_64_BIT_INO_T: u32 = 1; +pub const __DARWIN_ONLY_UNIX_CONFORMANCE: u32 = 1; +pub const __DARWIN_ONLY_VERS_1050: u32 = 1; +pub const __DARWIN_UNIX03: u32 = 1; +pub const __DARWIN_64_BIT_INO_T: u32 = 1; +pub const __DARWIN_VERS_1050: u32 = 1; +pub const __DARWIN_NON_CANCELABLE: u32 = 0; +pub const __DARWIN_SUF_EXTSN: &[u8; 14] = b"$DARWIN_EXTSN\0"; +pub const __DARWIN_C_ANSI: u32 = 4096; +pub const __DARWIN_C_FULL: u32 = 900000; +pub const __DARWIN_C_LEVEL: u32 = 900000; +pub const __STDC_WANT_LIB_EXT1__: u32 = 1; +pub const __DARWIN_NO_LONG_LONG: u32 = 0; +pub const _DARWIN_FEATURE_64_BIT_INODE: u32 = 1; +pub const _DARWIN_FEATURE_ONLY_64_BIT_INODE: u32 = 1; +pub const _DARWIN_FEATURE_ONLY_VERS_1050: u32 = 1; +pub const _DARWIN_FEATURE_ONLY_UNIX_CONFORMANCE: u32 = 1; +pub const _DARWIN_FEATURE_UNIX_CONFORMANCE: u32 = 3; +pub const __has_ptrcheck: u32 = 0; +pub const USE_CLANG_TYPES: u32 = 0; +pub const __PTHREAD_SIZE__: u32 = 8176; +pub const __PTHREAD_ATTR_SIZE__: u32 = 56; +pub const __PTHREAD_MUTEXATTR_SIZE__: u32 = 8; +pub const __PTHREAD_MUTEX_SIZE__: u32 = 56; +pub const __PTHREAD_CONDATTR_SIZE__: u32 = 8; +pub const __PTHREAD_COND_SIZE__: u32 = 40; +pub const __PTHREAD_ONCE_SIZE__: u32 = 8; +pub const __PTHREAD_RWLOCK_SIZE__: u32 = 192; +pub const __PTHREAD_RWLOCKATTR_SIZE__: u32 = 16; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; +pub const INT64_MAX: u64 = 9223372036854775807; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT64_MIN: i64 = -9223372036854775808; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const UINT64_MAX: i32 = -1; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST64_MIN: i64 = -9223372036854775808; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const INT_LEAST64_MAX: u64 = 9223372036854775807; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 4294967295; +pub const UINT_LEAST64_MAX: i32 = -1; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i32 = -32768; +pub const INT_FAST32_MIN: i32 = -2147483648; +pub const INT_FAST64_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u32 = 32767; +pub const INT_FAST32_MAX: u32 = 2147483647; +pub const INT_FAST64_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: u32 = 65535; +pub const UINT_FAST32_MAX: u32 = 4294967295; +pub const UINT_FAST64_MAX: i32 = -1; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const UINTPTR_MAX: i32 = -1; +pub const SIZE_MAX: i32 = -1; +pub const RSIZE_MAX: i32 = -1; +pub const WINT_MIN: i32 = -2147483648; +pub const WINT_MAX: u32 = 2147483647; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const __DARWIN_WCHAR_MIN: i32 = -2147483648; +pub const _FORTIFY_SOURCE: u32 = 2; +pub const __API_TO_BE_DEPRECATED: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_MACOS: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_IOS: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_MACCATALYST: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_WATCHOS: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_TVOS: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_DRIVERKIT: u32 = 100000; +pub const __API_TO_BE_DEPRECATED_VISIONOS: u32 = 100000; +pub const __MAC_10_0: u32 = 1000; +pub const __MAC_10_1: u32 = 1010; +pub const __MAC_10_2: u32 = 1020; +pub const __MAC_10_3: u32 = 1030; +pub const __MAC_10_4: u32 = 1040; +pub const __MAC_10_5: u32 = 1050; +pub const __MAC_10_6: u32 = 1060; +pub const __MAC_10_7: u32 = 1070; +pub const __MAC_10_8: u32 = 1080; +pub const __MAC_10_9: u32 = 1090; +pub const __MAC_10_10: u32 = 101000; +pub const __MAC_10_10_2: u32 = 101002; +pub const __MAC_10_10_3: u32 = 101003; +pub const __MAC_10_11: u32 = 101100; +pub const __MAC_10_11_2: u32 = 101102; +pub const __MAC_10_11_3: u32 = 101103; +pub const __MAC_10_11_4: u32 = 101104; +pub const __MAC_10_12: u32 = 101200; +pub const __MAC_10_12_1: u32 = 101201; +pub const __MAC_10_12_2: u32 = 101202; +pub const __MAC_10_12_4: u32 = 101204; +pub const __MAC_10_13: u32 = 101300; +pub const __MAC_10_13_1: u32 = 101301; +pub const __MAC_10_13_2: u32 = 101302; +pub const __MAC_10_13_4: u32 = 101304; +pub const __MAC_10_14: u32 = 101400; +pub const __MAC_10_14_1: u32 = 101401; +pub const __MAC_10_14_4: u32 = 101404; +pub const __MAC_10_14_5: u32 = 101405; +pub const __MAC_10_14_6: u32 = 101406; +pub const __MAC_10_15: u32 = 101500; +pub const __MAC_10_15_1: u32 = 101501; +pub const __MAC_10_15_4: u32 = 101504; +pub const __MAC_10_16: u32 = 101600; +pub const __MAC_11_0: u32 = 110000; +pub const __MAC_11_1: u32 = 110100; +pub const __MAC_11_3: u32 = 110300; +pub const __MAC_11_4: u32 = 110400; +pub const __MAC_11_5: u32 = 110500; +pub const __MAC_11_6: u32 = 110600; +pub const __MAC_12_0: u32 = 120000; +pub const __MAC_12_1: u32 = 120100; +pub const __MAC_12_2: u32 = 120200; +pub const __MAC_12_3: u32 = 120300; +pub const __MAC_12_4: u32 = 120400; +pub const __MAC_12_5: u32 = 120500; +pub const __MAC_12_6: u32 = 120600; +pub const __MAC_12_7: u32 = 120700; +pub const __MAC_13_0: u32 = 130000; +pub const __MAC_13_1: u32 = 130100; +pub const __MAC_13_2: u32 = 130200; +pub const __MAC_13_3: u32 = 130300; +pub const __MAC_13_4: u32 = 130400; +pub const __MAC_13_5: u32 = 130500; +pub const __MAC_13_6: u32 = 130600; +pub const __MAC_14_0: u32 = 140000; +pub const __MAC_14_1: u32 = 140100; +pub const __MAC_14_2: u32 = 140200; +pub const __MAC_14_3: u32 = 140300; +pub const __MAC_14_4: u32 = 140400; +pub const __MAC_14_5: u32 = 140500; +pub const __MAC_15_0: u32 = 150000; +pub const __MAC_15_1: u32 = 150100; +pub const __MAC_15_2: u32 = 150200; +pub const __IPHONE_2_0: u32 = 20000; +pub const __IPHONE_2_1: u32 = 20100; +pub const __IPHONE_2_2: u32 = 20200; +pub const __IPHONE_3_0: u32 = 30000; +pub const __IPHONE_3_1: u32 = 30100; +pub const __IPHONE_3_2: u32 = 30200; +pub const __IPHONE_4_0: u32 = 40000; +pub const __IPHONE_4_1: u32 = 40100; +pub const __IPHONE_4_2: u32 = 40200; +pub const __IPHONE_4_3: u32 = 40300; +pub const __IPHONE_5_0: u32 = 50000; +pub const __IPHONE_5_1: u32 = 50100; +pub const __IPHONE_6_0: u32 = 60000; +pub const __IPHONE_6_1: u32 = 60100; +pub const __IPHONE_7_0: u32 = 70000; +pub const __IPHONE_7_1: u32 = 70100; +pub const __IPHONE_8_0: u32 = 80000; +pub const __IPHONE_8_1: u32 = 80100; +pub const __IPHONE_8_2: u32 = 80200; +pub const __IPHONE_8_3: u32 = 80300; +pub const __IPHONE_8_4: u32 = 80400; +pub const __IPHONE_9_0: u32 = 90000; +pub const __IPHONE_9_1: u32 = 90100; +pub const __IPHONE_9_2: u32 = 90200; +pub const __IPHONE_9_3: u32 = 90300; +pub const __IPHONE_10_0: u32 = 100000; +pub const __IPHONE_10_1: u32 = 100100; +pub const __IPHONE_10_2: u32 = 100200; +pub const __IPHONE_10_3: u32 = 100300; +pub const __IPHONE_11_0: u32 = 110000; +pub const __IPHONE_11_1: u32 = 110100; +pub const __IPHONE_11_2: u32 = 110200; +pub const __IPHONE_11_3: u32 = 110300; +pub const __IPHONE_11_4: u32 = 110400; +pub const __IPHONE_12_0: u32 = 120000; +pub const __IPHONE_12_1: u32 = 120100; +pub const __IPHONE_12_2: u32 = 120200; +pub const __IPHONE_12_3: u32 = 120300; +pub const __IPHONE_12_4: u32 = 120400; +pub const __IPHONE_13_0: u32 = 130000; +pub const __IPHONE_13_1: u32 = 130100; +pub const __IPHONE_13_2: u32 = 130200; +pub const __IPHONE_13_3: u32 = 130300; +pub const __IPHONE_13_4: u32 = 130400; +pub const __IPHONE_13_5: u32 = 130500; +pub const __IPHONE_13_6: u32 = 130600; +pub const __IPHONE_13_7: u32 = 130700; +pub const __IPHONE_14_0: u32 = 140000; +pub const __IPHONE_14_1: u32 = 140100; +pub const __IPHONE_14_2: u32 = 140200; +pub const __IPHONE_14_3: u32 = 140300; +pub const __IPHONE_14_5: u32 = 140500; +pub const __IPHONE_14_4: u32 = 140400; +pub const __IPHONE_14_6: u32 = 140600; +pub const __IPHONE_14_7: u32 = 140700; +pub const __IPHONE_14_8: u32 = 140800; +pub const __IPHONE_15_0: u32 = 150000; +pub const __IPHONE_15_1: u32 = 150100; +pub const __IPHONE_15_2: u32 = 150200; +pub const __IPHONE_15_3: u32 = 150300; +pub const __IPHONE_15_4: u32 = 150400; +pub const __IPHONE_15_5: u32 = 150500; +pub const __IPHONE_15_6: u32 = 150600; +pub const __IPHONE_15_7: u32 = 150700; +pub const __IPHONE_15_8: u32 = 150800; +pub const __IPHONE_16_0: u32 = 160000; +pub const __IPHONE_16_1: u32 = 160100; +pub const __IPHONE_16_2: u32 = 160200; +pub const __IPHONE_16_3: u32 = 160300; +pub const __IPHONE_16_4: u32 = 160400; +pub const __IPHONE_16_5: u32 = 160500; +pub const __IPHONE_16_6: u32 = 160600; +pub const __IPHONE_16_7: u32 = 160700; +pub const __IPHONE_17_0: u32 = 170000; +pub const __IPHONE_17_1: u32 = 170100; +pub const __IPHONE_17_2: u32 = 170200; +pub const __IPHONE_17_3: u32 = 170300; +pub const __IPHONE_17_4: u32 = 170400; +pub const __IPHONE_17_5: u32 = 170500; +pub const __IPHONE_18_0: u32 = 180000; +pub const __IPHONE_18_1: u32 = 180100; +pub const __IPHONE_18_2: u32 = 180200; +pub const __WATCHOS_1_0: u32 = 10000; +pub const __WATCHOS_2_0: u32 = 20000; +pub const __WATCHOS_2_1: u32 = 20100; +pub const __WATCHOS_2_2: u32 = 20200; +pub const __WATCHOS_3_0: u32 = 30000; +pub const __WATCHOS_3_1: u32 = 30100; +pub const __WATCHOS_3_1_1: u32 = 30101; +pub const __WATCHOS_3_2: u32 = 30200; +pub const __WATCHOS_4_0: u32 = 40000; +pub const __WATCHOS_4_1: u32 = 40100; +pub const __WATCHOS_4_2: u32 = 40200; +pub const __WATCHOS_4_3: u32 = 40300; +pub const __WATCHOS_5_0: u32 = 50000; +pub const __WATCHOS_5_1: u32 = 50100; +pub const __WATCHOS_5_2: u32 = 50200; +pub const __WATCHOS_5_3: u32 = 50300; +pub const __WATCHOS_6_0: u32 = 60000; +pub const __WATCHOS_6_1: u32 = 60100; +pub const __WATCHOS_6_2: u32 = 60200; +pub const __WATCHOS_7_0: u32 = 70000; +pub const __WATCHOS_7_1: u32 = 70100; +pub const __WATCHOS_7_2: u32 = 70200; +pub const __WATCHOS_7_3: u32 = 70300; +pub const __WATCHOS_7_4: u32 = 70400; +pub const __WATCHOS_7_5: u32 = 70500; +pub const __WATCHOS_7_6: u32 = 70600; +pub const __WATCHOS_8_0: u32 = 80000; +pub const __WATCHOS_8_1: u32 = 80100; +pub const __WATCHOS_8_3: u32 = 80300; +pub const __WATCHOS_8_4: u32 = 80400; +pub const __WATCHOS_8_5: u32 = 80500; +pub const __WATCHOS_8_6: u32 = 80600; +pub const __WATCHOS_8_7: u32 = 80700; +pub const __WATCHOS_8_8: u32 = 80800; +pub const __WATCHOS_9_0: u32 = 90000; +pub const __WATCHOS_9_1: u32 = 90100; +pub const __WATCHOS_9_2: u32 = 90200; +pub const __WATCHOS_9_3: u32 = 90300; +pub const __WATCHOS_9_4: u32 = 90400; +pub const __WATCHOS_9_5: u32 = 90500; +pub const __WATCHOS_9_6: u32 = 90600; +pub const __WATCHOS_10_0: u32 = 100000; +pub const __WATCHOS_10_1: u32 = 100100; +pub const __WATCHOS_10_2: u32 = 100200; +pub const __WATCHOS_10_3: u32 = 100300; +pub const __WATCHOS_10_4: u32 = 100400; +pub const __WATCHOS_10_5: u32 = 100500; +pub const __WATCHOS_11_0: u32 = 110000; +pub const __WATCHOS_11_1: u32 = 110100; +pub const __WATCHOS_11_2: u32 = 110200; +pub const __TVOS_9_0: u32 = 90000; +pub const __TVOS_9_1: u32 = 90100; +pub const __TVOS_9_2: u32 = 90200; +pub const __TVOS_10_0: u32 = 100000; +pub const __TVOS_10_0_1: u32 = 100001; +pub const __TVOS_10_1: u32 = 100100; +pub const __TVOS_10_2: u32 = 100200; +pub const __TVOS_11_0: u32 = 110000; +pub const __TVOS_11_1: u32 = 110100; +pub const __TVOS_11_2: u32 = 110200; +pub const __TVOS_11_3: u32 = 110300; +pub const __TVOS_11_4: u32 = 110400; +pub const __TVOS_12_0: u32 = 120000; +pub const __TVOS_12_1: u32 = 120100; +pub const __TVOS_12_2: u32 = 120200; +pub const __TVOS_12_3: u32 = 120300; +pub const __TVOS_12_4: u32 = 120400; +pub const __TVOS_13_0: u32 = 130000; +pub const __TVOS_13_2: u32 = 130200; +pub const __TVOS_13_3: u32 = 130300; +pub const __TVOS_13_4: u32 = 130400; +pub const __TVOS_14_0: u32 = 140000; +pub const __TVOS_14_1: u32 = 140100; +pub const __TVOS_14_2: u32 = 140200; +pub const __TVOS_14_3: u32 = 140300; +pub const __TVOS_14_5: u32 = 140500; +pub const __TVOS_14_6: u32 = 140600; +pub const __TVOS_14_7: u32 = 140700; +pub const __TVOS_15_0: u32 = 150000; +pub const __TVOS_15_1: u32 = 150100; +pub const __TVOS_15_2: u32 = 150200; +pub const __TVOS_15_3: u32 = 150300; +pub const __TVOS_15_4: u32 = 150400; +pub const __TVOS_15_5: u32 = 150500; +pub const __TVOS_15_6: u32 = 150600; +pub const __TVOS_16_0: u32 = 160000; +pub const __TVOS_16_1: u32 = 160100; +pub const __TVOS_16_2: u32 = 160200; +pub const __TVOS_16_3: u32 = 160300; +pub const __TVOS_16_4: u32 = 160400; +pub const __TVOS_16_5: u32 = 160500; +pub const __TVOS_16_6: u32 = 160600; +pub const __TVOS_17_0: u32 = 170000; +pub const __TVOS_17_1: u32 = 170100; +pub const __TVOS_17_2: u32 = 170200; +pub const __TVOS_17_3: u32 = 170300; +pub const __TVOS_17_4: u32 = 170400; +pub const __TVOS_17_5: u32 = 170500; +pub const __TVOS_18_0: u32 = 180000; +pub const __TVOS_18_1: u32 = 180100; +pub const __TVOS_18_2: u32 = 180200; +pub const __BRIDGEOS_2_0: u32 = 20000; +pub const __BRIDGEOS_3_0: u32 = 30000; +pub const __BRIDGEOS_3_1: u32 = 30100; +pub const __BRIDGEOS_3_4: u32 = 30400; +pub const __BRIDGEOS_4_0: u32 = 40000; +pub const __BRIDGEOS_4_1: u32 = 40100; +pub const __BRIDGEOS_5_0: u32 = 50000; +pub const __BRIDGEOS_5_1: u32 = 50100; +pub const __BRIDGEOS_5_3: u32 = 50300; +pub const __BRIDGEOS_6_0: u32 = 60000; +pub const __BRIDGEOS_6_2: u32 = 60200; +pub const __BRIDGEOS_6_4: u32 = 60400; +pub const __BRIDGEOS_6_5: u32 = 60500; +pub const __BRIDGEOS_6_6: u32 = 60600; +pub const __BRIDGEOS_7_0: u32 = 70000; +pub const __BRIDGEOS_7_1: u32 = 70100; +pub const __BRIDGEOS_7_2: u32 = 70200; +pub const __BRIDGEOS_7_3: u32 = 70300; +pub const __BRIDGEOS_7_4: u32 = 70400; +pub const __BRIDGEOS_7_6: u32 = 70600; +pub const __BRIDGEOS_8_0: u32 = 80000; +pub const __BRIDGEOS_8_1: u32 = 80100; +pub const __BRIDGEOS_8_2: u32 = 80200; +pub const __BRIDGEOS_8_3: u32 = 80300; +pub const __BRIDGEOS_8_4: u32 = 80400; +pub const __BRIDGEOS_8_5: u32 = 80500; +pub const __BRIDGEOS_9_0: u32 = 90000; +pub const __BRIDGEOS_9_1: u32 = 90100; +pub const __BRIDGEOS_9_2: u32 = 90200; +pub const __DRIVERKIT_19_0: u32 = 190000; +pub const __DRIVERKIT_20_0: u32 = 200000; +pub const __DRIVERKIT_21_0: u32 = 210000; +pub const __DRIVERKIT_22_0: u32 = 220000; +pub const __DRIVERKIT_22_4: u32 = 220400; +pub const __DRIVERKIT_22_5: u32 = 220500; +pub const __DRIVERKIT_22_6: u32 = 220600; +pub const __DRIVERKIT_23_0: u32 = 230000; +pub const __DRIVERKIT_23_1: u32 = 230100; +pub const __DRIVERKIT_23_2: u32 = 230200; +pub const __DRIVERKIT_23_3: u32 = 230300; +pub const __DRIVERKIT_23_4: u32 = 230400; +pub const __DRIVERKIT_23_5: u32 = 230500; +pub const __DRIVERKIT_24_0: u32 = 240000; +pub const __DRIVERKIT_24_1: u32 = 240100; +pub const __DRIVERKIT_24_2: u32 = 240200; +pub const __VISIONOS_1_0: u32 = 10000; +pub const __VISIONOS_1_1: u32 = 10100; +pub const __VISIONOS_1_2: u32 = 10200; +pub const __VISIONOS_2_0: u32 = 20000; +pub const __VISIONOS_2_1: u32 = 20100; +pub const __VISIONOS_2_2: u32 = 20200; +pub const MAC_OS_X_VERSION_10_0: u32 = 1000; +pub const MAC_OS_X_VERSION_10_1: u32 = 1010; +pub const MAC_OS_X_VERSION_10_2: u32 = 1020; +pub const MAC_OS_X_VERSION_10_3: u32 = 1030; +pub const MAC_OS_X_VERSION_10_4: u32 = 1040; +pub const MAC_OS_X_VERSION_10_5: u32 = 1050; +pub const MAC_OS_X_VERSION_10_6: u32 = 1060; +pub const MAC_OS_X_VERSION_10_7: u32 = 1070; +pub const MAC_OS_X_VERSION_10_8: u32 = 1080; +pub const MAC_OS_X_VERSION_10_9: u32 = 1090; +pub const MAC_OS_X_VERSION_10_10: u32 = 101000; +pub const MAC_OS_X_VERSION_10_10_2: u32 = 101002; +pub const MAC_OS_X_VERSION_10_10_3: u32 = 101003; +pub const MAC_OS_X_VERSION_10_11: u32 = 101100; +pub const MAC_OS_X_VERSION_10_11_2: u32 = 101102; +pub const MAC_OS_X_VERSION_10_11_3: u32 = 101103; +pub const MAC_OS_X_VERSION_10_11_4: u32 = 101104; +pub const MAC_OS_X_VERSION_10_12: u32 = 101200; +pub const MAC_OS_X_VERSION_10_12_1: u32 = 101201; +pub const MAC_OS_X_VERSION_10_12_2: u32 = 101202; +pub const MAC_OS_X_VERSION_10_12_4: u32 = 101204; +pub const MAC_OS_X_VERSION_10_13: u32 = 101300; +pub const MAC_OS_X_VERSION_10_13_1: u32 = 101301; +pub const MAC_OS_X_VERSION_10_13_2: u32 = 101302; +pub const MAC_OS_X_VERSION_10_13_4: u32 = 101304; +pub const MAC_OS_X_VERSION_10_14: u32 = 101400; +pub const MAC_OS_X_VERSION_10_14_1: u32 = 101401; +pub const MAC_OS_X_VERSION_10_14_4: u32 = 101404; +pub const MAC_OS_X_VERSION_10_14_5: u32 = 101405; +pub const MAC_OS_X_VERSION_10_14_6: u32 = 101406; +pub const MAC_OS_X_VERSION_10_15: u32 = 101500; +pub const MAC_OS_X_VERSION_10_15_1: u32 = 101501; +pub const MAC_OS_X_VERSION_10_15_4: u32 = 101504; +pub const MAC_OS_X_VERSION_10_16: u32 = 101600; +pub const MAC_OS_VERSION_11_0: u32 = 110000; +pub const MAC_OS_VERSION_11_1: u32 = 110100; +pub const MAC_OS_VERSION_11_3: u32 = 110300; +pub const MAC_OS_VERSION_11_4: u32 = 110400; +pub const MAC_OS_VERSION_11_5: u32 = 110500; +pub const MAC_OS_VERSION_11_6: u32 = 110600; +pub const MAC_OS_VERSION_12_0: u32 = 120000; +pub const MAC_OS_VERSION_12_1: u32 = 120100; +pub const MAC_OS_VERSION_12_2: u32 = 120200; +pub const MAC_OS_VERSION_12_3: u32 = 120300; +pub const MAC_OS_VERSION_12_4: u32 = 120400; +pub const MAC_OS_VERSION_12_5: u32 = 120500; +pub const MAC_OS_VERSION_12_6: u32 = 120600; +pub const MAC_OS_VERSION_12_7: u32 = 120700; +pub const MAC_OS_VERSION_13_0: u32 = 130000; +pub const MAC_OS_VERSION_13_1: u32 = 130100; +pub const MAC_OS_VERSION_13_2: u32 = 130200; +pub const MAC_OS_VERSION_13_3: u32 = 130300; +pub const MAC_OS_VERSION_13_4: u32 = 130400; +pub const MAC_OS_VERSION_13_5: u32 = 130500; +pub const MAC_OS_VERSION_13_6: u32 = 130600; +pub const MAC_OS_VERSION_14_0: u32 = 140000; +pub const MAC_OS_VERSION_14_1: u32 = 140100; +pub const MAC_OS_VERSION_14_2: u32 = 140200; +pub const MAC_OS_VERSION_14_3: u32 = 140300; +pub const MAC_OS_VERSION_14_4: u32 = 140400; +pub const MAC_OS_VERSION_14_5: u32 = 140500; +pub const MAC_OS_VERSION_15_0: u32 = 150000; +pub const MAC_OS_VERSION_15_1: u32 = 150100; +pub const MAC_OS_VERSION_15_2: u32 = 150200; +pub const __MAC_OS_X_VERSION_MAX_ALLOWED: u32 = 150200; +pub const __ENABLE_LEGACY_MAC_AVAILABILITY: u32 = 1; +pub const USE_CLANG_STDDEF: u32 = 0; +pub const _USE_FORTIFY_LEVEL: u32 = 2; +pub const __HAS_FIXED_CHK_PROTOTYPES: u32 = 1; +pub type wchar_t = ::std::os::raw::c_int; +pub type max_align_t = f64; +pub type int_least8_t = i8; +pub type int_least16_t = i16; +pub type int_least32_t = i32; +pub type int_least64_t = i64; +pub type uint_least8_t = u8; +pub type uint_least16_t = u16; +pub type uint_least32_t = u32; +pub type uint_least64_t = u64; +pub type int_fast8_t = i8; +pub type int_fast16_t = i16; +pub type int_fast32_t = i32; +pub type int_fast64_t = i64; +pub type uint_fast8_t = u8; +pub type uint_fast16_t = u16; +pub type uint_fast32_t = u32; +pub type uint_fast64_t = u64; +pub type __int8_t = ::std::os::raw::c_schar; +pub type __uint8_t = ::std::os::raw::c_uchar; +pub type __int16_t = ::std::os::raw::c_short; +pub type __uint16_t = ::std::os::raw::c_ushort; +pub type __int32_t = ::std::os::raw::c_int; +pub type __uint32_t = ::std::os::raw::c_uint; +pub type __int64_t = ::std::os::raw::c_longlong; +pub type __uint64_t = ::std::os::raw::c_ulonglong; +pub type __darwin_intptr_t = ::std::os::raw::c_long; +pub type __darwin_natural_t = ::std::os::raw::c_uint; +pub type __darwin_ct_rune_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Copy, Clone)] +pub union __mbstate_t { + pub __mbstate8: [::std::os::raw::c_char; 128usize], + pub _mbstateL: ::std::os::raw::c_longlong, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __mbstate_t"][::std::mem::size_of::<__mbstate_t>() - 128usize]; + ["Alignment of __mbstate_t"][::std::mem::align_of::<__mbstate_t>() - 8usize]; + ["Offset of field: __mbstate_t::__mbstate8"] + [::std::mem::offset_of!(__mbstate_t, __mbstate8) - 0usize]; + ["Offset of field: __mbstate_t::_mbstateL"] + [::std::mem::offset_of!(__mbstate_t, _mbstateL) - 0usize]; +}; +pub type __darwin_mbstate_t = __mbstate_t; +pub type __darwin_ptrdiff_t = ::std::os::raw::c_long; +pub type __darwin_size_t = ::std::os::raw::c_ulong; +pub type __darwin_va_list = __builtin_va_list; +pub type __darwin_wchar_t = ::std::os::raw::c_int; +pub type __darwin_rune_t = __darwin_wchar_t; +pub type __darwin_wint_t = ::std::os::raw::c_int; +pub type __darwin_clock_t = ::std::os::raw::c_ulong; +pub type __darwin_socklen_t = __uint32_t; +pub type __darwin_ssize_t = ::std::os::raw::c_long; +pub type __darwin_time_t = ::std::os::raw::c_long; +pub type __darwin_blkcnt_t = __int64_t; +pub type __darwin_blksize_t = __int32_t; +pub type __darwin_dev_t = __int32_t; +pub type __darwin_fsblkcnt_t = ::std::os::raw::c_uint; +pub type __darwin_fsfilcnt_t = ::std::os::raw::c_uint; +pub type __darwin_gid_t = __uint32_t; +pub type __darwin_id_t = __uint32_t; +pub type __darwin_ino64_t = __uint64_t; +pub type __darwin_ino_t = __darwin_ino64_t; +pub type __darwin_mach_port_name_t = __darwin_natural_t; +pub type __darwin_mach_port_t = __darwin_mach_port_name_t; +pub type __darwin_mode_t = __uint16_t; +pub type __darwin_off_t = __int64_t; +pub type __darwin_pid_t = __int32_t; +pub type __darwin_sigset_t = __uint32_t; +pub type __darwin_suseconds_t = __int32_t; +pub type __darwin_uid_t = __uint32_t; +pub type __darwin_useconds_t = __uint32_t; +pub type __darwin_uuid_t = [::std::os::raw::c_uchar; 16usize]; +pub type __darwin_uuid_string_t = [::std::os::raw::c_char; 37usize]; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct __darwin_pthread_handler_rec { + pub __routine: ::std::option::Option, + pub __arg: *mut ::std::os::raw::c_void, + pub __next: *mut __darwin_pthread_handler_rec, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of __darwin_pthread_handler_rec"] + [::std::mem::size_of::<__darwin_pthread_handler_rec>() - 24usize]; + ["Alignment of __darwin_pthread_handler_rec"] + [::std::mem::align_of::<__darwin_pthread_handler_rec>() - 8usize]; + ["Offset of field: __darwin_pthread_handler_rec::__routine"] + [::std::mem::offset_of!(__darwin_pthread_handler_rec, __routine) - 0usize]; + ["Offset of field: __darwin_pthread_handler_rec::__arg"] + [::std::mem::offset_of!(__darwin_pthread_handler_rec, __arg) - 8usize]; + ["Offset of field: __darwin_pthread_handler_rec::__next"] + [::std::mem::offset_of!(__darwin_pthread_handler_rec, __next) - 16usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_attr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 56usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_attr_t"][::std::mem::size_of::<_opaque_pthread_attr_t>() - 64usize]; + ["Alignment of _opaque_pthread_attr_t"] + [::std::mem::align_of::<_opaque_pthread_attr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_attr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_attr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_attr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_attr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_cond_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 40usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_cond_t"][::std::mem::size_of::<_opaque_pthread_cond_t>() - 48usize]; + ["Alignment of _opaque_pthread_cond_t"] + [::std::mem::align_of::<_opaque_pthread_cond_t>() - 8usize]; + ["Offset of field: _opaque_pthread_cond_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_cond_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_cond_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_cond_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_condattr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 8usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_condattr_t"] + [::std::mem::size_of::<_opaque_pthread_condattr_t>() - 16usize]; + ["Alignment of _opaque_pthread_condattr_t"] + [::std::mem::align_of::<_opaque_pthread_condattr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_condattr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_condattr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_condattr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_condattr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_mutex_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 56usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_mutex_t"][::std::mem::size_of::<_opaque_pthread_mutex_t>() - 64usize]; + ["Alignment of _opaque_pthread_mutex_t"] + [::std::mem::align_of::<_opaque_pthread_mutex_t>() - 8usize]; + ["Offset of field: _opaque_pthread_mutex_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_mutex_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_mutex_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_mutex_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_mutexattr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 8usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_mutexattr_t"] + [::std::mem::size_of::<_opaque_pthread_mutexattr_t>() - 16usize]; + ["Alignment of _opaque_pthread_mutexattr_t"] + [::std::mem::align_of::<_opaque_pthread_mutexattr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_mutexattr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_mutexattr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_mutexattr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_once_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 8usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_once_t"][::std::mem::size_of::<_opaque_pthread_once_t>() - 16usize]; + ["Alignment of _opaque_pthread_once_t"] + [::std::mem::align_of::<_opaque_pthread_once_t>() - 8usize]; + ["Offset of field: _opaque_pthread_once_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_once_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_once_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_once_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_rwlock_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 192usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_rwlock_t"] + [::std::mem::size_of::<_opaque_pthread_rwlock_t>() - 200usize]; + ["Alignment of _opaque_pthread_rwlock_t"] + [::std::mem::align_of::<_opaque_pthread_rwlock_t>() - 8usize]; + ["Offset of field: _opaque_pthread_rwlock_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_rwlock_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_rwlock_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_rwlockattr_t { + pub __sig: ::std::os::raw::c_long, + pub __opaque: [::std::os::raw::c_char; 16usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_rwlockattr_t"] + [::std::mem::size_of::<_opaque_pthread_rwlockattr_t>() - 24usize]; + ["Alignment of _opaque_pthread_rwlockattr_t"] + [::std::mem::align_of::<_opaque_pthread_rwlockattr_t>() - 8usize]; + ["Offset of field: _opaque_pthread_rwlockattr_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_rwlockattr_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_rwlockattr_t, __opaque) - 8usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _opaque_pthread_t { + pub __sig: ::std::os::raw::c_long, + pub __cleanup_stack: *mut __darwin_pthread_handler_rec, + pub __opaque: [::std::os::raw::c_char; 8176usize], +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of _opaque_pthread_t"][::std::mem::size_of::<_opaque_pthread_t>() - 8192usize]; + ["Alignment of _opaque_pthread_t"][::std::mem::align_of::<_opaque_pthread_t>() - 8usize]; + ["Offset of field: _opaque_pthread_t::__sig"] + [::std::mem::offset_of!(_opaque_pthread_t, __sig) - 0usize]; + ["Offset of field: _opaque_pthread_t::__cleanup_stack"] + [::std::mem::offset_of!(_opaque_pthread_t, __cleanup_stack) - 8usize]; + ["Offset of field: _opaque_pthread_t::__opaque"] + [::std::mem::offset_of!(_opaque_pthread_t, __opaque) - 16usize]; +}; +pub type __darwin_pthread_attr_t = _opaque_pthread_attr_t; +pub type __darwin_pthread_cond_t = _opaque_pthread_cond_t; +pub type __darwin_pthread_condattr_t = _opaque_pthread_condattr_t; +pub type __darwin_pthread_key_t = ::std::os::raw::c_ulong; +pub type __darwin_pthread_mutex_t = _opaque_pthread_mutex_t; +pub type __darwin_pthread_mutexattr_t = _opaque_pthread_mutexattr_t; +pub type __darwin_pthread_once_t = _opaque_pthread_once_t; +pub type __darwin_pthread_rwlock_t = _opaque_pthread_rwlock_t; +pub type __darwin_pthread_rwlockattr_t = _opaque_pthread_rwlockattr_t; +pub type __darwin_pthread_t = *mut _opaque_pthread_t; +pub type intmax_t = ::std::os::raw::c_long; +pub type uintmax_t = ::std::os::raw::c_ulong; +pub type __darwin_nl_item = ::std::os::raw::c_int; +pub type __darwin_wctrans_t = ::std::os::raw::c_int; +pub type __darwin_wctype_t = __uint32_t; +unsafe extern "C" { + pub fn memchr( + __s: *const ::std::os::raw::c_void, + __c: ::std::os::raw::c_int, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn memcmp( + __s1: *const ::std::os::raw::c_void, + __s2: *const ::std::os::raw::c_void, + __n: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn memcpy( + __dst: *mut ::std::os::raw::c_void, + __src: *const ::std::os::raw::c_void, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn memmove( + __dst: *mut ::std::os::raw::c_void, + __src: *const ::std::os::raw::c_void, + __len: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn memset( + __b: *mut ::std::os::raw::c_void, + __c: ::std::os::raw::c_int, + __len: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn strcat( + __s1: *mut ::std::os::raw::c_char, + __s2: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strchr( + __s: *const ::std::os::raw::c_char, + __c: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strcmp( + __s1: *const ::std::os::raw::c_char, + __s2: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strcoll( + __s1: *const ::std::os::raw::c_char, + __s2: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strcpy( + __dst: *mut ::std::os::raw::c_char, + __src: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strcspn( + __s: *const ::std::os::raw::c_char, + __charset: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn strerror(__errnum: ::std::os::raw::c_int) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strlen(__s: *const ::std::os::raw::c_char) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn strncat( + __s1: *mut ::std::os::raw::c_char, + __s2: *const ::std::os::raw::c_char, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strncmp( + __s1: *const ::std::os::raw::c_char, + __s2: *const ::std::os::raw::c_char, + __n: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strncpy( + __dst: *mut ::std::os::raw::c_char, + __src: *const ::std::os::raw::c_char, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strpbrk( + __s: *const ::std::os::raw::c_char, + __charset: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strrchr( + __s: *const ::std::os::raw::c_char, + __c: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strspn( + __s: *const ::std::os::raw::c_char, + __charset: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn strstr( + __big: *const ::std::os::raw::c_char, + __little: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strtok( + __str: *mut ::std::os::raw::c_char, + __sep: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strxfrm( + __s1: *mut ::std::os::raw::c_char, + __s2: *const ::std::os::raw::c_char, + __n: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn strtok_r( + __str: *mut ::std::os::raw::c_char, + __sep: *const ::std::os::raw::c_char, + __lasts: *mut *mut ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strerror_r( + __errnum: ::std::os::raw::c_int, + __strerrbuf: *mut ::std::os::raw::c_char, + __buflen: usize, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strdup(__s1: *const ::std::os::raw::c_char) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn memccpy( + __dst: *mut ::std::os::raw::c_void, + __src: *const ::std::os::raw::c_void, + __c: ::std::os::raw::c_int, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn stpcpy( + __dst: *mut ::std::os::raw::c_char, + __src: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn stpncpy( + __dst: *mut ::std::os::raw::c_char, + __src: *const ::std::os::raw::c_char, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strndup( + __s1: *const ::std::os::raw::c_char, + __n: ::std::os::raw::c_ulong, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strnlen(__s1: *const ::std::os::raw::c_char, __n: usize) -> usize; +} +unsafe extern "C" { + pub fn strsignal(__sig: ::std::os::raw::c_int) -> *mut ::std::os::raw::c_char; +} +pub type u_int8_t = ::std::os::raw::c_uchar; +pub type u_int16_t = ::std::os::raw::c_ushort; +pub type u_int32_t = ::std::os::raw::c_uint; +pub type u_int64_t = ::std::os::raw::c_ulonglong; +pub type register_t = i64; +pub type user_addr_t = u_int64_t; +pub type user_size_t = u_int64_t; +pub type user_ssize_t = i64; +pub type user_long_t = i64; +pub type user_ulong_t = u_int64_t; +pub type user_time_t = i64; +pub type user_off_t = i64; +pub type syscall_arg_t = u_int64_t; +pub type rsize_t = __darwin_size_t; +pub type errno_t = ::std::os::raw::c_int; +unsafe extern "C" { + pub fn memset_s( + __s: *mut ::std::os::raw::c_void, + __smax: rsize_t, + __c: ::std::os::raw::c_int, + __n: rsize_t, + ) -> errno_t; +} +unsafe extern "C" { + pub fn memmem( + __big: *const ::std::os::raw::c_void, + __big_len: usize, + __little: *const ::std::os::raw::c_void, + __little_len: usize, + ) -> *mut ::std::os::raw::c_void; +} +unsafe extern "C" { + pub fn memset_pattern4( + __b: *mut ::std::os::raw::c_void, + __pattern4: *const ::std::os::raw::c_void, + __len: usize, + ); +} +unsafe extern "C" { + pub fn memset_pattern8( + __b: *mut ::std::os::raw::c_void, + __pattern8: *const ::std::os::raw::c_void, + __len: usize, + ); +} +unsafe extern "C" { + pub fn memset_pattern16( + __b: *mut ::std::os::raw::c_void, + __pattern16: *const ::std::os::raw::c_void, + __len: usize, + ); +} +unsafe extern "C" { + pub fn strcasestr( + __big: *const ::std::os::raw::c_char, + __little: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strnstr( + __big: *const ::std::os::raw::c_char, + __little: *const ::std::os::raw::c_char, + __len: usize, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn strlcat( + __dst: *mut ::std::os::raw::c_char, + __source: *const ::std::os::raw::c_char, + __size: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn strlcpy( + __dst: *mut ::std::os::raw::c_char, + __source: *const ::std::os::raw::c_char, + __size: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_ulong; +} +unsafe extern "C" { + pub fn strmode(__mode: ::std::os::raw::c_int, __bp: *mut ::std::os::raw::c_char); +} +unsafe extern "C" { + pub fn strsep( + __stringp: *mut *mut ::std::os::raw::c_char, + __delim: *const ::std::os::raw::c_char, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn swab( + arg1: *const ::std::os::raw::c_void, + arg2: *mut ::std::os::raw::c_void, + arg3: isize, + ); +} +unsafe extern "C" { + pub fn timingsafe_bcmp( + __b1: *const ::std::os::raw::c_void, + __b2: *const ::std::os::raw::c_void, + __len: usize, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strsignal_r( + __sig: ::std::os::raw::c_int, + __strsignalbuf: *mut ::std::os::raw::c_char, + __buflen: usize, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn bcmp( + arg1: *const ::std::os::raw::c_void, + arg2: *const ::std::os::raw::c_void, + arg3: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn bcopy( + arg1: *const ::std::os::raw::c_void, + arg2: *mut ::std::os::raw::c_void, + arg3: usize, + ); +} +unsafe extern "C" { + pub fn bzero(arg1: *mut ::std::os::raw::c_void, arg2: ::std::os::raw::c_ulong); +} +unsafe extern "C" { + pub fn index( + arg1: *const ::std::os::raw::c_char, + arg2: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn rindex( + arg1: *const ::std::os::raw::c_char, + arg2: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_char; +} +unsafe extern "C" { + pub fn ffs(arg1: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strcasecmp( + arg1: *const ::std::os::raw::c_char, + arg2: *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn strncasecmp( + arg1: *const ::std::os::raw::c_char, + arg2: *const ::std::os::raw::c_char, + arg3: ::std::os::raw::c_ulong, + ) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ffsl(arg1: ::std::os::raw::c_long) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn ffsll(arg1: ::std::os::raw::c_longlong) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn fls(arg1: ::std::os::raw::c_int) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn flsl(arg1: ::std::os::raw::c_long) -> ::std::os::raw::c_int; +} +unsafe extern "C" { + pub fn flsll(arg1: ::std::os::raw::c_longlong) -> ::std::os::raw::c_int; +} +#[repr(u32)] +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum rng_type_t { + STD_DEFAULT_RNG = 0, + CUDA_RNG = 1, +} +#[repr(u32)] +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum sample_method_t { + EULER_A = 0, + EULER = 1, + HEUN = 2, + DPM2 = 3, + DPMPP2S_A = 4, + DPMPP2M = 5, + DPMPP2Mv2 = 6, + IPNDM = 7, + IPNDM_V = 8, + LCM = 9, + DDIM_TRAILING = 10, + TCD = 11, + N_SAMPLE_METHODS = 12, +} +#[repr(u32)] +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum schedule_t { + DEFAULT = 0, + DISCRETE = 1, + KARRAS = 2, + EXPONENTIAL = 3, + AYS = 4, + GITS = 5, + N_SCHEDULES = 6, +} +#[repr(u32)] +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum sd_type_t { + SD_TYPE_F32 = 0, + SD_TYPE_F16 = 1, + SD_TYPE_Q4_0 = 2, + SD_TYPE_Q4_1 = 3, + SD_TYPE_Q5_0 = 6, + SD_TYPE_Q5_1 = 7, + SD_TYPE_Q8_0 = 8, + SD_TYPE_Q8_1 = 9, + SD_TYPE_Q2_K = 10, + SD_TYPE_Q3_K = 11, + SD_TYPE_Q4_K = 12, + SD_TYPE_Q5_K = 13, + SD_TYPE_Q6_K = 14, + SD_TYPE_Q8_K = 15, + SD_TYPE_IQ2_XXS = 16, + SD_TYPE_IQ2_XS = 17, + SD_TYPE_IQ3_XXS = 18, + SD_TYPE_IQ1_S = 19, + SD_TYPE_IQ4_NL = 20, + SD_TYPE_IQ3_S = 21, + SD_TYPE_IQ2_S = 22, + SD_TYPE_IQ4_XS = 23, + SD_TYPE_I8 = 24, + SD_TYPE_I16 = 25, + SD_TYPE_I32 = 26, + SD_TYPE_I64 = 27, + SD_TYPE_F64 = 28, + SD_TYPE_IQ1_M = 29, + SD_TYPE_BF16 = 30, + SD_TYPE_TQ1_0 = 34, + SD_TYPE_TQ2_0 = 35, + SD_TYPE_COUNT = 39, +} +unsafe extern "C" { + pub fn sd_type_name(type_: sd_type_t) -> *const ::std::os::raw::c_char; +} +#[repr(u32)] +#[non_exhaustive] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum sd_log_level_t { + SD_LOG_DEBUG = 0, + SD_LOG_INFO = 1, + SD_LOG_WARN = 2, + SD_LOG_ERROR = 3, +} +pub type sd_log_cb_t = ::std::option::Option< + unsafe extern "C" fn( + level: sd_log_level_t, + text: *const ::std::os::raw::c_char, + data: *mut ::std::os::raw::c_void, + ), +>; +pub type sd_progress_cb_t = ::std::option::Option< + unsafe extern "C" fn( + step: ::std::os::raw::c_int, + steps: ::std::os::raw::c_int, + time: f32, + data: *mut ::std::os::raw::c_void, + ), +>; +unsafe extern "C" { + pub fn sd_set_log_callback(sd_log_cb: sd_log_cb_t, data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub fn sd_set_progress_callback(cb: sd_progress_cb_t, data: *mut ::std::os::raw::c_void); +} +unsafe extern "C" { + pub fn get_num_physical_cores() -> i32; +} +unsafe extern "C" { + pub fn sd_get_system_info() -> *const ::std::os::raw::c_char; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct sd_image_t { + pub width: u32, + pub height: u32, + pub channel: u32, + pub data: *mut u8, +} +#[allow(clippy::unnecessary_operation, clippy::identity_op)] +const _: () = { + ["Size of sd_image_t"][::std::mem::size_of::() - 24usize]; + ["Alignment of sd_image_t"][::std::mem::align_of::() - 8usize]; + ["Offset of field: sd_image_t::width"][::std::mem::offset_of!(sd_image_t, width) - 0usize]; + ["Offset of field: sd_image_t::height"][::std::mem::offset_of!(sd_image_t, height) - 4usize]; + ["Offset of field: sd_image_t::channel"][::std::mem::offset_of!(sd_image_t, channel) - 8usize]; + ["Offset of field: sd_image_t::data"][::std::mem::offset_of!(sd_image_t, data) - 16usize]; +}; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct sd_ctx_t { + _unused: [u8; 0], +} +unsafe extern "C" { + pub fn new_sd_ctx( + model_path: *const ::std::os::raw::c_char, + clip_l_path: *const ::std::os::raw::c_char, + clip_g_path: *const ::std::os::raw::c_char, + t5xxl_path: *const ::std::os::raw::c_char, + diffusion_model_path: *const ::std::os::raw::c_char, + vae_path: *const ::std::os::raw::c_char, + taesd_path: *const ::std::os::raw::c_char, + control_net_path_c_str: *const ::std::os::raw::c_char, + lora_model_dir: *const ::std::os::raw::c_char, + embed_dir_c_str: *const ::std::os::raw::c_char, + stacked_id_embed_dir_c_str: *const ::std::os::raw::c_char, + vae_decode_only: bool, + vae_tiling: bool, + free_params_immediately: bool, + n_threads: ::std::os::raw::c_int, + wtype: sd_type_t, + rng_type: rng_type_t, + s: schedule_t, + keep_clip_on_cpu: bool, + keep_control_net_cpu: bool, + keep_vae_on_cpu: bool, + diffusion_flash_attn: bool, + ) -> *mut sd_ctx_t; +} +unsafe extern "C" { + pub fn free_sd_ctx(sd_ctx: *mut sd_ctx_t); +} +unsafe extern "C" { + pub fn txt2img( + sd_ctx: *mut sd_ctx_t, + prompt: *const ::std::os::raw::c_char, + negative_prompt: *const ::std::os::raw::c_char, + clip_skip: ::std::os::raw::c_int, + cfg_scale: f32, + guidance: f32, + eta: f32, + width: ::std::os::raw::c_int, + height: ::std::os::raw::c_int, + sample_method: sample_method_t, + sample_steps: ::std::os::raw::c_int, + seed: i64, + batch_count: ::std::os::raw::c_int, + control_cond: *const sd_image_t, + control_strength: f32, + style_strength: f32, + normalize_input: bool, + input_id_images_path: *const ::std::os::raw::c_char, + skip_layers: *mut ::std::os::raw::c_int, + skip_layers_count: usize, + slg_scale: f32, + skip_layer_start: f32, + skip_layer_end: f32, + ) -> *mut sd_image_t; +} +unsafe extern "C" { + pub fn img2img( + sd_ctx: *mut sd_ctx_t, + init_image: sd_image_t, + mask_image: sd_image_t, + prompt: *const ::std::os::raw::c_char, + negative_prompt: *const ::std::os::raw::c_char, + clip_skip: ::std::os::raw::c_int, + cfg_scale: f32, + guidance: f32, + eta: f32, + width: ::std::os::raw::c_int, + height: ::std::os::raw::c_int, + sample_method: sample_method_t, + sample_steps: ::std::os::raw::c_int, + strength: f32, + seed: i64, + batch_count: ::std::os::raw::c_int, + control_cond: *const sd_image_t, + control_strength: f32, + style_strength: f32, + normalize_input: bool, + input_id_images_path: *const ::std::os::raw::c_char, + skip_layers: *mut ::std::os::raw::c_int, + skip_layers_count: usize, + slg_scale: f32, + skip_layer_start: f32, + skip_layer_end: f32, + ) -> *mut sd_image_t; +} +unsafe extern "C" { + pub fn img2vid( + sd_ctx: *mut sd_ctx_t, + init_image: sd_image_t, + width: ::std::os::raw::c_int, + height: ::std::os::raw::c_int, + video_frames: ::std::os::raw::c_int, + motion_bucket_id: ::std::os::raw::c_int, + fps: ::std::os::raw::c_int, + augmentation_level: f32, + min_cfg: f32, + cfg_scale: f32, + sample_method: sample_method_t, + sample_steps: ::std::os::raw::c_int, + strength: f32, + seed: i64, + ) -> *mut sd_image_t; +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct upscaler_ctx_t { + _unused: [u8; 0], +} +unsafe extern "C" { + pub fn new_upscaler_ctx( + esrgan_path: *const ::std::os::raw::c_char, + n_threads: ::std::os::raw::c_int, + ) -> *mut upscaler_ctx_t; +} +unsafe extern "C" { + pub fn free_upscaler_ctx(upscaler_ctx: *mut upscaler_ctx_t); +} +unsafe extern "C" { + pub fn upscale( + upscaler_ctx: *mut upscaler_ctx_t, + input_image: sd_image_t, + upscale_factor: u32, + ) -> sd_image_t; +} +unsafe extern "C" { + pub fn convert( + input_path: *const ::std::os::raw::c_char, + vae_path: *const ::std::os::raw::c_char, + output_path: *const ::std::os::raw::c_char, + output_type: sd_type_t, + ) -> bool; +} +unsafe extern "C" { + pub fn preprocess_canny( + img: *mut u8, + width: ::std::os::raw::c_int, + height: ::std::os::raw::c_int, + high_threshold: f32, + low_threshold: f32, + weak: f32, + strong: f32, + inverse: bool, + ) -> *mut u8; +} +pub type __builtin_va_list = *mut ::std::os::raw::c_char; From ae4150cee7ffa0bb8127c238c770a0fb1aef233a Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 21 Mar 2025 19:35:17 -0400 Subject: [PATCH 26/33] update sdcpp --- Cargo.toml | 3 +-- sys/stable-diffusion.cpp | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 27feb48..0687d37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,12 +18,11 @@ keywords.workspace = true description = "High level API for stable-diffusion.cpp" documentation = "https://docs.rs/diffusion-rs" - [dependencies] derive_builder = "0.20.2" diffusion-rs-sys = { path = "sys", version = "0.1.8" } image = "0.25.5" -libc = "0.2.161" +libc = "0.2.171" num_cpus = "1.16.0" thiserror = "2.0.12" diff --git a/sys/stable-diffusion.cpp b/sys/stable-diffusion.cpp index 30b3ac8..10c6501 160000 --- a/sys/stable-diffusion.cpp +++ b/sys/stable-diffusion.cpp @@ -1 +1 @@ -Subproject commit 30b3ac8e6279c128a7a3f8c2627d31b96c2a1185 +Subproject commit 10c6501bd05a697e014f1bee3a84e5664290c489 From 873e6f193b0c5440e4e505739a43e262cd41144b Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Wed, 26 Mar 2025 20:05:38 -0400 Subject: [PATCH 27/33] add logging data parameter for callbacks --- src/api.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++- src/model_config.rs | 4 ++++ src/utils.rs | 7 ++++--- 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/src/api.rs b/src/api.rs index 1393726..2110126 100644 --- a/src/api.rs +++ b/src/api.rs @@ -9,6 +9,7 @@ use diffusion_rs_sys::sd_image_t; use image::RgbImage; use libc::free; +#[derive(Debug)] pub struct ModelCtx { /// The underlying C context ctx: *mut diffusion_rs_sys::sd_ctx_t, @@ -22,7 +23,11 @@ unsafe impl Sync for ModelCtx {} impl ModelCtx { pub fn new(config: &ModelConfig) -> Result { - setup_logging(config.log_callback, config.progress_callback); + setup_logging( + config.log_callback, + config.progress_callback, + config.logging_data, + ); let ctx = unsafe { let ptr = diffusion_rs_sys::new_sd_ctx( @@ -160,10 +165,45 @@ mod tests { use super::*; use crate::utils::{RngFunction, SampleMethod, Schedule, WeightType}; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; + use diffusion_rs_sys::sd_log_level_t; use image::ImageReader; + use std::ffi::{CStr, c_char, c_int}; use std::sync::{Arc, Mutex}; use std::thread; + #[derive(Debug)] + struct State { + state: String, + } + + extern "C" fn testing_log_callback( + level: sd_log_level_t, + text: *const c_char, + _data: *mut c_void, + ) { + unsafe { + // Convert C string to Rust &str and print it. + if !text.is_null() { + let msg = CStr::from_ptr(text) + .to_str() + .unwrap_or("LOG ERROR: Invalid UTF-8"); + print!("({:?}): {}", level, msg); + } + } + } + + extern "C" fn testing_progress_callback( + step: c_int, + steps: c_int, + time: f32, + _data: *mut c_void, + ) { + // Convert C string to Rust &str and print it. + let data: &mut State = unsafe { &mut *(_data as *mut State) }; + data.state = format!("Updated at {}", time); + print!("({:?}): {} {:?}", (step, steps), time, data); + } + #[test] fn test_invalid_model_config() { let config = ModelConfigBuilder::default().build(); @@ -200,6 +240,11 @@ mod tests { #[test] fn test_txt2img_singlethreaded_success() { + let mut state = State { + state: String::from("beginning"), + }; + let state_ptr: *mut c_void = &mut state as *mut _ as *mut c_void; + let model_config = ModelConfigBuilder::default() .model("./models/mistoonAnime_v30.safetensors") .lora_model_dir("./models/loras") @@ -210,6 +255,10 @@ mod tests { .schedule(Schedule::AYS) .vae_decode_only(true) .flash_attention(true) + .logging_data(state_ptr) + .progress_callback( + testing_progress_callback as extern "C" fn(i32, i32, f32, *mut libc::c_void), + ) .build() .expect("Failed to build model config"); diff --git a/src/model_config.rs b/src/model_config.rs index 65f81ca..9d84f24 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,5 +1,6 @@ use std::ffi::{c_char, c_void}; use std::path::PathBuf; +use std::ptr::null_mut; use derive_builder::Builder; use diffusion_rs_sys::{get_num_physical_cores, sd_log_level_t}; @@ -113,6 +114,9 @@ pub struct ModelConfig { #[builder(default = "None")] pub progress_callback: Option, + + #[builder(default = "null_mut()")] + pub logging_data: *mut c_void, } impl ModelConfigBuilder { diff --git a/src/utils.rs b/src/utils.rs index 31124a1..d10d60f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -91,14 +91,15 @@ extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _ pub fn setup_logging( log_callback: Option, progress_callback: Option, + logging_data: *mut c_void, ) { unsafe { match log_callback { - Some(callback) => sd_set_log_callback(Some(callback), std::ptr::null_mut()), - None => sd_set_log_callback(Some(default_log_callback), std::ptr::null_mut()), + Some(callback) => sd_set_log_callback(Some(callback), logging_data), + None => sd_set_log_callback(Some(default_log_callback), logging_data), }; match progress_callback { - Some(callback) => sd_set_progress_callback(Some(callback), std::ptr::null_mut()), + Some(callback) => sd_set_progress_callback(Some(callback), logging_data), None => (), }; } From fd645c8912fb63d31d263a862474edfe2b514bdf Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Wed, 26 Mar 2025 20:54:28 -0400 Subject: [PATCH 28/33] test probably not good code --- src/api.rs | 9 ++++++--- src/model_config.rs | 6 +++--- src/utils.rs | 8 ++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/api.rs b/src/api.rs index 2110126..81d663d 100644 --- a/src/api.rs +++ b/src/api.rs @@ -26,7 +26,7 @@ impl ModelCtx { setup_logging( config.log_callback, config.progress_callback, - config.logging_data, + config.logging_data.as_ref().unwrap().data, ); let ctx = unsafe { @@ -163,7 +163,7 @@ impl Drop for ModelCtx { #[cfg(test)] mod tests { use super::*; - use crate::utils::{RngFunction, SampleMethod, Schedule, WeightType}; + use crate::utils::{Data, RngFunction, SampleMethod, Schedule, WeightType}; use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; use diffusion_rs_sys::sd_log_level_t; use image::ImageReader; @@ -243,7 +243,10 @@ mod tests { let mut state = State { state: String::from("beginning"), }; - let state_ptr: *mut c_void = &mut state as *mut _ as *mut c_void; + + let state_ptr: Data = Data { + data: &mut state as *mut State as *mut c_void, + }; let model_config = ModelConfigBuilder::default() .model("./models/mistoonAnime_v30.safetensors") diff --git a/src/model_config.rs b/src/model_config.rs index 9d84f24..18738e9 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -5,7 +5,7 @@ use std::ptr::null_mut; use derive_builder::Builder; use diffusion_rs_sys::{get_num_physical_cores, sd_log_level_t}; -use crate::utils::{RngFunction, Schedule, WeightType}; +use crate::utils::{Data, RngFunction, Schedule, WeightType}; #[derive(Builder, Debug, Clone)] #[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] @@ -115,8 +115,8 @@ pub struct ModelConfig { pub progress_callback: Option, - #[builder(default = "null_mut()")] - pub logging_data: *mut c_void, + #[builder(default = "None")] + pub logging_data: Option, } impl ModelConfigBuilder { diff --git a/src/utils.rs b/src/utils.rs index d10d60f..43d916f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -105,6 +105,14 @@ pub fn setup_logging( } } +#[derive(Debug, Clone)] +pub struct Data { + pub data: *mut c_void, +} + +unsafe impl Send for Data {} +unsafe impl Sync for Data {} + // use std::sync::LazyLock; // static BAR: LazyLock> = LazyLock::new(|| { From 289734c2ce9e20948abae46d4b811e154d902aac Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Thu, 27 Mar 2025 14:06:08 -0400 Subject: [PATCH 29/33] clean up --- src/api.rs | 32 ++++++++++++++++++++++++++++++-- src/model_config.rs | 6 ++---- src/utils.rs | 37 +++++++------------------------------ 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/src/api.rs b/src/api.rs index 81d663d..29384fa 100644 --- a/src/api.rs +++ b/src/api.rs @@ -168,6 +168,7 @@ mod tests { use diffusion_rs_sys::sd_log_level_t; use image::ImageReader; use std::ffi::{CStr, c_char, c_int}; + use std::sync::mpsc::{self, Sender}; use std::sync::{Arc, Mutex}; use std::thread; @@ -204,6 +205,19 @@ mod tests { print!("({:?}): {} {:?}", (step, steps), time, data); } + extern "C" fn testing_progress_callback2( + step: c_int, + steps: c_int, + time: f32, + _data: *mut c_void, + ) { + // Convert C string to Rust &str and print it. + let data: &Sender = unsafe { &*(_data as *const Sender) }; + data.send(format!("Updated at {}", time)) + .expect("Failed to send progress update"); + print!("({:?}): {} {:?}", (step, steps), time, data); + } + #[test] fn test_invalid_model_config() { let config = ModelConfigBuilder::default().build(); @@ -295,7 +309,7 @@ mod tests { .height(resolution) .width(resolution) .clip_skip(2) - .batch_count(5) + .batch_count(1) .build() .expect("Failed to build txt2img config 1"); @@ -311,6 +325,12 @@ mod tests { #[test] fn test_txt2img_multithreaded_success() { + let (sender, receiver) = mpsc::channel(); + + let sender_ptr: Data = Data { + data: &mut sender.clone() as *mut Sender as *mut c_void, + }; + let model_config = ModelConfigBuilder::default() .model("./models/mistoonAnime_v30.safetensors") .lora_model_dir("./models/loras") @@ -321,6 +341,10 @@ mod tests { .schedule(Schedule::AYS) .vae_decode_only(true) .flash_attention(false) + .logging_data(sender_ptr) + .progress_callback( + testing_progress_callback2 as extern "C" fn(i32, i32, f32, *mut c_void), + ) .build() .expect("Failed to build model config"); @@ -370,8 +394,12 @@ mod tests { .expect("Failed to build txt2img config"); let ctx = Arc::clone(&ctx); - + let sender_clone = sender.clone(); let handle = thread::spawn(move || { + sender_clone + .send(format!("Thread {} started", index)) + .expect("Failed to send message"); + let result = ctx .lock() .unwrap() diff --git a/src/model_config.rs b/src/model_config.rs index 18738e9..0e1012d 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,9 +1,7 @@ -use std::ffi::{c_char, c_void}; -use std::path::PathBuf; -use std::ptr::null_mut; - use derive_builder::Builder; use diffusion_rs_sys::{get_num_physical_cores, sd_log_level_t}; +use std::ffi::{c_char, c_void}; +use std::path::PathBuf; use crate::utils::{Data, RngFunction, Schedule, WeightType}; diff --git a/src/utils.rs b/src/utils.rs index 43d916f..fd47188 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -88,9 +88,14 @@ extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _ } } +type UnsafeLogCallbackFn = + extern "C" fn(level: sd_log_level_t, text: *const c_char, data: *mut c_void); + +type UnsafeProgressCallbackFn = extern "C" fn(step: i32, steps: i32, time: f32, data: *mut c_void); + pub fn setup_logging( - log_callback: Option, - progress_callback: Option, + log_callback: Option, + progress_callback: Option, logging_data: *mut c_void, ) { unsafe { @@ -111,31 +116,3 @@ pub struct Data { } unsafe impl Send for Data {} -unsafe impl Sync for Data {} - -// use std::sync::LazyLock; - -// static BAR: LazyLock> = LazyLock::new(|| { -// Mutex::new(ProgressBar::no_length().with_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})") -// .unwrap() -// .progress_chars("#>-"))) -// }); - -// /// This is your C callback that gets called with current progress. -// extern "C" fn default_progress_callback(step: c_int, steps: c_int, time: f32, _data: *mut c_void) { -// // Update the global progress bar if it's been initialized. -// let mut bar = BAR.lock().unwrap(); - -// if bar.is_finished() { -// *bar = ProgressBar::no_length().with_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})") -// .unwrap() -// .progress_chars("#>-")); -// } else { -// if steps == step { -// bar.finish_with_message("Done"); -// } -// bar.set_length(steps as u64); -// bar.set_position(step as u64); -// bar.set_message(format!("Elapsed: {:.2} s", time)); -// } -// } From b2c36105cbc105cf3837f96948b8da4e4afe6a4d Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 28 Mar 2025 13:18:20 -0400 Subject: [PATCH 30/33] fix and support closures for log callbacks --- Cargo.toml | 3 +- src/api.rs | 418 ++----------------------------------- src/lib.rs | 3 +- src/model_config.rs | 83 ++++++-- src/txt2img_config.rs | 21 +- src/types.rs | 41 ++++ src/utils.rs | 93 +-------- tests/integration_tests.rs | 278 ++++++++++++++++++++++++ 8 files changed, 429 insertions(+), 511 deletions(-) create mode 100644 src/types.rs create mode 100644 tests/integration_tests.rs diff --git a/Cargo.toml b/Cargo.toml index 0687d37..275ba71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,8 @@ documentation = "https://docs.rs/diffusion-rs" [dependencies] derive_builder = "0.20.2" diffusion-rs-sys = { path = "sys", version = "0.1.8" } -image = "0.25.5" +ffi_helpers = { git = "https://github.com/anzbert/ffi_helpers.git", branch = "master" } +image = "0.25.6" libc = "0.2.171" num_cpus = "1.16.0" thiserror = "2.0.12" diff --git a/src/api.rs b/src/api.rs index 29384fa..c937fc4 100644 --- a/src/api.rs +++ b/src/api.rs @@ -4,7 +4,9 @@ use std::slice; use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; -use crate::utils::{DiffusionError, convert_image, pathbuf_to_c_char, setup_logging}; +use crate::types::DiffusionError; +use crate::utils::{convert_image, pathbuf_to_c_char}; + use diffusion_rs_sys::sd_image_t; use image::RgbImage; use libc::free; @@ -23,11 +25,16 @@ unsafe impl Sync for ModelCtx {} impl ModelCtx { pub fn new(config: &ModelConfig) -> Result { - setup_logging( - config.log_callback, - config.progress_callback, - config.logging_data.as_ref().unwrap().data, - ); + unsafe { + diffusion_rs_sys::sd_set_log_callback(config.log_callback.0, config.log_callback.1) + }; + + unsafe { + diffusion_rs_sys::sd_set_progress_callback( + config.progress_callback.0, + config.progress_callback.1, + ) + }; let ctx = unsafe { let ptr = diffusion_rs_sys::new_sd_ctx( @@ -67,10 +74,7 @@ impl ModelCtx { }) } - pub fn txt2img( - &self, - txt2img_config: &mut Txt2ImgConfig, - ) -> Result, DiffusionError> { + pub fn txt2img(&self, txt2img_config: &Txt2ImgConfig) -> Result, DiffusionError> { // add loras to prompt as suffix let prompt: CString = { let mut prompt = txt2img_config.prompt.clone(); @@ -84,11 +88,11 @@ impl ModelCtx { .expect("Failed to convert negative prompt to CString"); //controlnet - let control_image: *const sd_image_t = match txt2img_config.control_cond.as_mut() { + let control_image: *const sd_image_t = match txt2img_config.control_cond.as_ref() { Some(image) => { if self.config.control_net.is_file() { &sd_image_t { - data: image.as_mut_ptr(), + data: image.as_ptr().cast_mut(), width: image.width(), height: image.height(), channel: 3, @@ -124,7 +128,7 @@ impl ModelCtx { txt2img_config.style_strength, txt2img_config.normalize_input, pathbuf_to_c_char(&txt2img_config.input_id_images).as_ptr(), - txt2img_config.skip_layer.as_mut_ptr(), + txt2img_config.skip_layer.clone().as_mut_ptr(), txt2img_config.skip_layer.len(), txt2img_config.slg_scale, txt2img_config.skip_layer_start, @@ -159,391 +163,3 @@ impl Drop for ModelCtx { unsafe { diffusion_rs_sys::free_sd_ctx(self.ctx) }; } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::utils::{Data, RngFunction, SampleMethod, Schedule, WeightType}; - use crate::{model_config::ModelConfigBuilder, txt2img_config::Txt2ImgConfigBuilder}; - use diffusion_rs_sys::sd_log_level_t; - use image::ImageReader; - use std::ffi::{CStr, c_char, c_int}; - use std::sync::mpsc::{self, Sender}; - use std::sync::{Arc, Mutex}; - use std::thread; - - #[derive(Debug)] - struct State { - state: String, - } - - extern "C" fn testing_log_callback( - level: sd_log_level_t, - text: *const c_char, - _data: *mut c_void, - ) { - unsafe { - // Convert C string to Rust &str and print it. - if !text.is_null() { - let msg = CStr::from_ptr(text) - .to_str() - .unwrap_or("LOG ERROR: Invalid UTF-8"); - print!("({:?}): {}", level, msg); - } - } - } - - extern "C" fn testing_progress_callback( - step: c_int, - steps: c_int, - time: f32, - _data: *mut c_void, - ) { - // Convert C string to Rust &str and print it. - let data: &mut State = unsafe { &mut *(_data as *mut State) }; - data.state = format!("Updated at {}", time); - print!("({:?}): {} {:?}", (step, steps), time, data); - } - - extern "C" fn testing_progress_callback2( - step: c_int, - steps: c_int, - time: f32, - _data: *mut c_void, - ) { - // Convert C string to Rust &str and print it. - let data: &Sender = unsafe { &*(_data as *const Sender) }; - data.send(format!("Updated at {}", time)) - .expect("Failed to send progress update"); - print!("({:?}): {} {:?}", (step, steps), time, data); - } - - #[test] - fn test_invalid_model_config() { - let config = ModelConfigBuilder::default().build(); - assert!(config.is_err(), "ModelConfig should fail without a model"); - } - - #[test] - fn test_valid_model_config() { - let config = ModelConfigBuilder::default().model("./test.ckpt").build(); - assert!(config.is_ok(), "ModelConfig should succeed with model path"); - } - - #[test] - fn test_invalid_txt2img_config() { - let config = Txt2ImgConfigBuilder::default().build(); - assert!(config.is_err(), "Txt2ImgConfig should fail without prompt"); - } - - #[test] - fn test_valid_txt2img_config() { - let config = Txt2ImgConfigBuilder::default() - .prompt("testing prompt") - .build(); - assert!(config.is_ok(), "Txt2ImgConfig should succeed with prompt"); - } - - #[test] - fn test_model_ctx_new_invalid() { - let config = ModelConfigBuilder::default().build(); - assert!(config.is_err()); - // Attempt creating ModelCtx with error - // This is hypothetical; we expect a builder error before this - } - - #[test] - fn test_txt2img_singlethreaded_success() { - let mut state = State { - state: String::from("beginning"), - }; - - let state_ptr: Data = Data { - data: &mut state as *mut State as *mut c_void, - }; - - let model_config = ModelConfigBuilder::default() - .model("./models/mistoonAnime_v30.safetensors") - .lora_model_dir("./models/loras") - .taesd("./models/taesd1.safetensors") - .control_net("./models/controlnet/control_canny-fp16.safetensors") - .weight_type(WeightType::SD_TYPE_F16) - .rng_type(RngFunction::CUDA_RNG) - .schedule(Schedule::AYS) - .vae_decode_only(true) - .flash_attention(true) - .logging_data(state_ptr) - .progress_callback( - testing_progress_callback as extern "C" fn(i32, i32, f32, *mut libc::c_void), - ) - .build() - .expect("Failed to build model config"); - - let ctx = ModelCtx::new(&model_config).expect("Failed to build model context"); - - let resolution: i32 = 384; - let sample_steps = 3; - let control_strength = 0.8; - let control_image = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .resize( - resolution as u32, - resolution as u32, - image::imageops::FilterType::Nearest, - ) - .into_rgb8(); - - let prompt = "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk"; - - let mut txt2img_config = Txt2ImgConfigBuilder::default() - .prompt(prompt) - .add_lora_model("pcm_sd15_smallcfg_2step_converted", 1.0) - .control_cond(control_image) - .control_strength(control_strength) - .sample_steps(sample_steps) - .sample_method(SampleMethod::TCD) - .eta(1.0) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(2) - .batch_count(1) - .build() - .expect("Failed to build txt2img config 1"); - - let result = ctx - .txt2img(&mut txt2img_config) - .expect("Failed to generate image 1"); - - result.iter().enumerate().for_each(|(batch, img)| { - img.save(format!("./images/test_st_{}x_{}.png", resolution, batch)) - .unwrap(); - }); - } - - #[test] - fn test_txt2img_multithreaded_success() { - let (sender, receiver) = mpsc::channel(); - - let sender_ptr: Data = Data { - data: &mut sender.clone() as *mut Sender as *mut c_void, - }; - - let model_config = ModelConfigBuilder::default() - .model("./models/mistoonAnime_v30.safetensors") - .lora_model_dir("./models/loras") - .taesd("./models/taesd1.safetensors") - .control_net("./models/controlnet/control_canny-fp16.safetensors") - .weight_type(WeightType::SD_TYPE_F16) - .rng_type(RngFunction::CUDA_RNG) - .schedule(Schedule::AYS) - .vae_decode_only(true) - .flash_attention(false) - .logging_data(sender_ptr) - .progress_callback( - testing_progress_callback2 as extern "C" fn(i32, i32, f32, *mut c_void), - ) - .build() - .expect("Failed to build model config"); - - let ctx = Arc::new(Mutex::new( - ModelCtx::new(&model_config).expect("Failed to build model context"), - )); - - let resolution: i32 = 384; - let sample_steps = 3; - let control_strength = 0.8; - let control_image = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .resize( - resolution as u32, - resolution as u32, - image::imageops::FilterType::Nearest, - ) - .into_rgb8(); - - let prompts = vec![ - "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", - "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", - "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy", - ]; - - let mut handles = vec![]; - - let mut binding = Txt2ImgConfigBuilder::default(); - let txt2img_config_base = binding - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .control_cond(control_image) - .control_strength(control_strength) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(2) - .batch_count(1); - - for (index, prompt) in prompts.into_iter().enumerate() { - let mut txt2img_config = txt2img_config_base - .prompt(prompt) - .build() - .expect("Failed to build txt2img config"); - - let ctx = Arc::clone(&ctx); - let sender_clone = sender.clone(); - let handle = thread::spawn(move || { - sender_clone - .send(format!("Thread {} started", index)) - .expect("Failed to send message"); - - let result = ctx - .lock() - .unwrap() - .txt2img(&mut txt2img_config) - .expect("Failed to generate image"); - - result.iter().enumerate().for_each(|(batch, img)| { - img.save(format!( - "./images/test_mt_#{}_{}x_{}.png", - index, resolution, batch - )) - .unwrap(); - }); - }); - - handles.push(handle); - } - - for handle in handles { - handle.join().unwrap(); - } - } - - #[test] - fn test_txt2img_multithreaded_multimodel_success() { - let model_config = ModelConfigBuilder::default() - .model("./models/mistoonAnime_v30.safetensors") - .lora_model_dir("./models/loras") - .taesd("./models/taesd1.safetensors") - .control_net("./models/controlnet/control_canny-fp16.safetensors") - .weight_type(WeightType::SD_TYPE_F16) - .rng_type(RngFunction::CUDA_RNG) - .schedule(Schedule::AYS) - .vae_decode_only(true) - .flash_attention(true) - .build() - .expect("Failed to build model config"); - - let ctx1 = ModelCtx::new(&model_config).expect("Failed to build model context"); - let ctx2 = ModelCtx::new(&model_config).expect("Failed to build model context"); - - let models = Arc::new(vec![ctx1, ctx2]); - - let resolution: i32 = 384; - let sample_steps = 4; - let control_strength = 0.5; - let control_image = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .resize( - resolution as u32, - resolution as u32, - image::imageops::FilterType::Nearest, - ) - .into_rgb8(); - - let prompts = vec![ - "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", - //"masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", - ]; - - let mut handles = vec![]; - - let mut binding = Txt2ImgConfigBuilder::default(); - let txt2img_config_base = binding - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .control_cond(control_image) - .control_strength(control_strength) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(2) - .batch_count(1); - - for (index, prompt) in prompts.into_iter().enumerate() { - let mut txt2img_config = txt2img_config_base - .prompt(prompt) - .build() - .expect("Failed to build txt2img config"); - - let models = Arc::clone(&models); - let handle = thread::spawn(move || { - let result = models[index] - .txt2img(&mut txt2img_config) - .expect("Failed to generate image"); - - result.iter().enumerate().for_each(|(batch, img)| { - img.save(format!( - "./images/test_mt_mm_#{}_{}x_{}.png", - index, resolution, batch - )) - .unwrap(); - }); - println!("Thread {} finished", index); - }); - - handles.push(handle); - } - - for handle in handles { - handle.join().unwrap(); - } - } - - #[test] - fn test_txt2img_failure() { - // Build a context with invalid data to force failure - let config = ModelConfigBuilder::default() - .model("./mistoonAnime_v10Illustrious.safetensors") - .build() - .unwrap(); - let ctx = ModelCtx::new(&config).expect("Failed to build model context"); - let mut txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("test prompt") - .sample_steps(1) - .build() - .unwrap(); - // Hypothetical failure scenario - let result = ctx.txt2img(&mut txt2img_conf); - // Expect an error if calling with invalid path - // This depends on your real implementation - assert!(result.is_err() || result.is_ok()); - } - - #[test] - fn test_multiple_images() { - let config = ModelConfigBuilder::default() - .model("./mistoonAnime_v10Illustrious.safetensors") - .build() - .unwrap(); - let ctx = ModelCtx::new(&config).expect("Failed to build model context"); - let mut txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("multi-image prompt") - .sample_steps(1) - .batch_count(3) - .build() - .unwrap(); - let result = ctx.txt2img(&mut txt2img_conf); - assert!(result.is_ok()); - if let Ok(images) = result { - assert_eq!(images.len(), 3); - } - } -} diff --git a/src/lib.rs b/src/lib.rs index fdec561..2418856 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,4 +4,5 @@ pub mod api; pub mod model_config; pub mod txt2img_config; -pub mod utils; +pub mod types; +mod utils; diff --git a/src/model_config.rs b/src/model_config.rs index 0e1012d..c17c636 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,12 +1,19 @@ +use crate::types::{RngFunction, Schedule, SdLogLevel, WeightType}; use derive_builder::Builder; -use diffusion_rs_sys::{get_num_physical_cores, sd_log_level_t}; -use std::ffi::{c_char, c_void}; +use std::ffi::{CStr, c_char, c_void}; use std::path::PathBuf; -use crate::utils::{Data, RngFunction, Schedule, WeightType}; +type UnsafeLogCallbackFn = unsafe extern "C" fn( + level: diffusion_rs_sys::sd_log_level_t, + text: *const c_char, + data: *mut c_void, +); + +type UnsafeProgressCallbackFn = + unsafe extern "C" fn(step: i32, steps: i32, time: f32, data: *mut c_void); #[derive(Builder, Debug, Clone)] -#[builder(setter(into, strip_option), build_fn(validate = "Self::validate"))] +#[builder(setter(into), build_fn(validate = "Self::validate"))] /// Config struct common to all diffusion methods pub struct ModelConfig { /// Path to full model @@ -67,8 +74,11 @@ pub struct ModelConfig { pub free_params_immediately: bool, /// Number of threads to use during computation (default: 0). - /// If n_ threads <= 0, then threads will be set to the number of CPU physical cores. - #[builder(default = "unsafe { get_num_physical_cores() }", setter(custom))] + /// If n_threads <= 0, then threads will be set to the number of CPU physical cores. + #[builder( + default = "unsafe { diffusion_rs_sys::get_num_physical_cores() }", + setter(custom) + )] pub n_threads: i32, /// Weight type. If not specified, the default is the type of the weight file @@ -104,17 +114,12 @@ pub struct ModelConfig { pub flash_attention: bool, /// set log callback function for cpp logs (default: None) - #[builder(default = "None")] - pub log_callback: - Option, + #[builder(setter(custom))] + pub log_callback: (Option, *mut c_void), /// set log callback function for progress logs (default: None) - #[builder(default = "None")] - pub progress_callback: - Option, - - #[builder(default = "None")] - pub logging_data: Option, + #[builder(setter(custom))] + pub progress_callback: (Option, *mut c_void), } impl ModelConfigBuilder { @@ -122,7 +127,7 @@ impl ModelConfigBuilder { self.n_threads = if value > 0 { Some(value) } else { - Some(unsafe { get_num_physical_cores() }) + Some(unsafe { diffusion_rs_sys::get_num_physical_cores() }) }; self } @@ -140,4 +145,50 @@ impl ModelConfigBuilder { "Model OR DiffusionModel must be initialized", )) } + + pub fn log_callback(&mut self, mut closure: F) -> &mut Self + where + F: FnMut(SdLogLevel, String) + Send + Sync, + { + let mut unsafe_closure = |level: diffusion_rs_sys::sd_log_level_t, text: *const c_char| { + let msg = unsafe { CStr::from_ptr(text) } + .to_str() + .unwrap_or("LOG ERROR: Invalid UTF-8"); + let level = SdLogLevel::from(level); + (closure)(level, msg.to_string()); + }; + + let (state, callback) = + unsafe { ffi_helpers::split_closure_trailing_data(&mut unsafe_closure) }; + let adapted_callback: UnsafeLogCallbackFn = callback; + self.log_callback = Some((Some(adapted_callback), state)); + self + } + + pub fn progress_callback(&mut self, mut closure: F) -> &mut Self + where + F: FnMut(i32, i32, f32) + Send + Sync, + { + let (state, callback) = unsafe { ffi_helpers::split_closure_trailing_data(&mut closure) }; + let adapted_callback: UnsafeProgressCallbackFn = callback; + self.progress_callback = Some((Some(adapted_callback), state)); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_invalid_model_config() { + let config = ModelConfigBuilder::default().build(); + assert!(config.is_err(), "ModelConfig should fail without a model"); + } + + #[test] + fn test_valid_model_config() { + let config = ModelConfigBuilder::default().model("./test.ckpt").build(); + assert!(config.is_ok(), "ModelConfig should succeed with model path"); + } } diff --git a/src/txt2img_config.rs b/src/txt2img_config.rs index f508093..e879f5f 100644 --- a/src/txt2img_config.rs +++ b/src/txt2img_config.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use crate::utils::SampleMethod; +use crate::types::SampleMethod; use derive_builder::Builder; use image::RgbImage; @@ -117,3 +117,22 @@ impl Txt2ImgConfigBuilder { self } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_invalid_txt2img_config() { + let config = Txt2ImgConfigBuilder::default().build(); + assert!(config.is_err(), "Txt2ImgConfig should fail without prompt"); + } + + #[test] + fn test_valid_txt2img_config() { + let config = Txt2ImgConfigBuilder::default() + .prompt("testing prompt") + .build(); + assert!(config.is_ok(), "Txt2ImgConfig should succeed with prompt"); + } +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..c1eeb46 --- /dev/null +++ b/src/types.rs @@ -0,0 +1,41 @@ +/// Specify the range function +pub use diffusion_rs_sys::rng_type_t as RngFunction; + +/// Denoiser sigma schedule +pub use diffusion_rs_sys::schedule_t as Schedule; + +/// Weight type +pub use diffusion_rs_sys::sd_type_t as WeightType; + +/// Sampling methods +pub use diffusion_rs_sys::sample_method_t as SampleMethod; + +/// Log Level +pub use diffusion_rs_sys::sd_log_level_t as SdLogLevel; + +#[non_exhaustive] +#[derive(thiserror::Error, Debug)] +/// Error that can occurs while forwarding models +pub enum DiffusionError { + #[error("The underling stablediffusion.cpp function returned NULL")] + Forward, + #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] + StoreImages(usize, i32), + #[error("The underling upscaler model returned a NULL image")] + Upscaler, + #[error("raw_ctx is None")] + NoContext, + #[error("new_sd_ctx returned null")] + NewContextFailure, + #[error("SD image conversion error: {0}")] + SDImageError(#[from] SDImageError), +} + +#[non_exhaustive] +#[derive(Debug, thiserror::Error)] +pub enum SDImageError { + #[error("Failed to convert image buffer to Rust type")] + AllocationError, + #[error("The image buffer has a different length than expected")] + DifferentLength, +} diff --git a/src/utils.rs b/src/utils.rs index fd47188..e91646f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,66 +1,18 @@ -use diffusion_rs_sys::sd_log_level_t; -use diffusion_rs_sys::sd_set_log_callback; -use diffusion_rs_sys::sd_set_progress_callback; use image::ImageBuffer; use image::Rgb; use image::RgbImage; -use std::ffi::CStr; use std::ffi::CString; -use std::ffi::c_char; -use std::ffi::c_void; use std::path::PathBuf; use std::slice; -use thiserror::Error; - -#[non_exhaustive] -#[derive(Error, Debug)] -/// Error that can occurs while forwarding models -pub enum DiffusionError { - #[error("The underling stablediffusion.cpp function returned NULL")] - Forward, - #[error("The underling stbi_write_image function returned 0 while saving image {0}/{1})")] - StoreImages(usize, i32), - #[error("The underling upscaler model returned a NULL image")] - Upscaler, - #[error("raw_ctx is None")] - NoContext, - #[error("new_sd_ctx returned null")] - NewContextFailure, - #[error("SD image conversion error: {0}")] - SDImageError(#[from] SDImageError), -} - -#[non_exhaustive] -#[derive(Debug, thiserror::Error)] -pub enum SDImageError { - #[error("Failed to convert image buffer to Rust type")] - AllocationError, - #[error("The image buffer has a different length than expected")] - DifferentLength, -} - -/// Specify the range function -pub use diffusion_rs_sys::rng_type_t as RngFunction; - -/// Denoiser sigma schedule -pub use diffusion_rs_sys::schedule_t as Schedule; - -/// Weight type -pub use diffusion_rs_sys::sd_type_t as WeightType; - -/// Sampling methods -pub use diffusion_rs_sys::sample_method_t as SampleMethod; - -//log level -pub use diffusion_rs_sys::sd_log_level_t as SdLogLevel; use diffusion_rs_sys::sd_image_t; +use crate::types::SDImageError; + pub fn pathbuf_to_c_char(path: &PathBuf) -> CString { let path_str = path .to_str() .expect("PathBuf contained non-UTF-8 characters"); - // Create a CString which adds a null terminator. CString::new(path_str).expect("CString conversion failed") } @@ -75,44 +27,3 @@ pub fn convert_image(sd_image: &sd_image_t) -> Result { None => return Err(SDImageError::AllocationError), }) } - -extern "C" fn default_log_callback(level: sd_log_level_t, text: *const c_char, _data: *mut c_void) { - unsafe { - // Convert C string to Rust &str and print it. - if !text.is_null() { - let msg = CStr::from_ptr(text) - .to_str() - .unwrap_or("LOG ERROR: Invalid UTF-8"); - print!("({:?}): {}", level, msg); - } - } -} - -type UnsafeLogCallbackFn = - extern "C" fn(level: sd_log_level_t, text: *const c_char, data: *mut c_void); - -type UnsafeProgressCallbackFn = extern "C" fn(step: i32, steps: i32, time: f32, data: *mut c_void); - -pub fn setup_logging( - log_callback: Option, - progress_callback: Option, - logging_data: *mut c_void, -) { - unsafe { - match log_callback { - Some(callback) => sd_set_log_callback(Some(callback), logging_data), - None => sd_set_log_callback(Some(default_log_callback), logging_data), - }; - match progress_callback { - Some(callback) => sd_set_progress_callback(Some(callback), logging_data), - None => (), - }; - } -} - -#[derive(Debug, Clone)] -pub struct Data { - pub data: *mut c_void, -} - -unsafe impl Send for Data {} diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs new file mode 100644 index 0000000..58c2e3d --- /dev/null +++ b/tests/integration_tests.rs @@ -0,0 +1,278 @@ +use diffusion_rs::api::ModelCtx; +use diffusion_rs::model_config::ModelConfigBuilder; +use diffusion_rs::txt2img_config::Txt2ImgConfigBuilder; +use diffusion_rs::types::{RngFunction, SampleMethod, Schedule, WeightType}; +use image::ImageReader; + +use std::sync::{Arc, Mutex}; +use std::thread; + +#[test] +fn test_txt2img_singlethreaded_success() { + let model_config = ModelConfigBuilder::default() + .model("./models/mistoonAnime_v30.safetensors") + .lora_model_dir("./models/loras") + .taesd("./models/taesd1.safetensors") + .control_net("./models/controlnet/control_canny-fp16.safetensors") + .schedule(Schedule::AYS) + .vae_decode_only(true) + .flash_attention(true) + .log_callback(|level, text| { + print!("({:?}): {}", level, text); + }) + .progress_callback(|step, steps, time| { + println!("Progress: {}/{} ({}s)", step, steps, time); + }) + .build() + .expect("Failed to build model config"); + + let ctx = ModelCtx::new(&model_config).expect("Failed to build model context"); + + let resolution: i32 = 384; + let sample_steps = 3; + let control_strength = 0.8; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) + .into_rgb8(); + + let prompt = "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk"; + + let txt2img_config = Txt2ImgConfigBuilder::default() + .prompt(prompt) + .add_lora_model("pcm_sd15_smallcfg_2step_converted", 1.0) + .control_cond(control_image) + .control_strength(control_strength) + .sample_steps(sample_steps) + .sample_method(SampleMethod::TCD) + .eta(1.0) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(2) + .batch_count(1) + .build() + .expect("Failed to build txt2img config 1"); + + let result = ctx + .txt2img(&txt2img_config) + .expect("Failed to generate image 1"); + + result.iter().enumerate().for_each(|(batch, img)| { + img.save(format!("./images/test_st_{}x_{}.png", resolution, batch)) + .unwrap(); + }); +} + +#[test] +fn test_txt2img_multithreaded_success() { + let model_config = ModelConfigBuilder::default() + .model("./models/mistoonAnime_v30.safetensors") + .lora_model_dir("./models/loras") + .taesd("./models/taesd1.safetensors") + .control_net("./models/controlnet/control_canny-fp16.safetensors") + .weight_type(WeightType::SD_TYPE_F16) + .rng_type(RngFunction::CUDA_RNG) + .schedule(Schedule::AYS) + .vae_decode_only(true) + .flash_attention(false) + .build() + .expect("Failed to build model config"); + + let ctx = Arc::new(Mutex::new( + ModelCtx::new(&model_config).expect("Failed to build model context"), + )); + + let resolution: i32 = 384; + let sample_steps = 3; + let control_strength = 0.8; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) + .into_rgb8(); + + let prompts = vec![ + "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", + "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + "masterpiece, best quality, absurdres, 1girl, medium hair, brown hair, green eyes, dark skin, dark green sweater, cat ears, nyan, sexy", + ]; + + let mut handles = vec![]; + + let mut builder = Txt2ImgConfigBuilder::default(); + + builder + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .control_cond(control_image) + .control_strength(control_strength) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(2) + .batch_count(1); + + for (index, prompt) in prompts.into_iter().enumerate() { + builder.prompt(prompt); + let txt2img_config = builder.build().expect("Failed to build txt2img config"); + + let ctx = Arc::clone(&ctx); + + let handle = thread::spawn(move || { + let result = ctx + .lock() + .unwrap() + .txt2img(&txt2img_config) + .expect("Failed to generate image"); + + result.iter().enumerate().for_each(|(batch, img)| { + img.save(format!( + "./images/test_mt_#{}_{}x_{}.png", + index, resolution, batch + )) + .unwrap(); + }); + }); + + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } +} + +#[test] +fn test_txt2img_multithreaded_multimodel_success() { + let model_config = ModelConfigBuilder::default() + .model("./models/mistoonAnime_v30.safetensors") + .lora_model_dir("./models/loras") + .taesd("./models/taesd1.safetensors") + .control_net("./models/controlnet/control_canny-fp16.safetensors") + .schedule(Schedule::AYS) + .vae_decode_only(true) + .flash_attention(true) + .build() + .expect("Failed to build model config"); + + let ctx1 = ModelCtx::new(&model_config).expect("Failed to build model context"); + let ctx2 = ModelCtx::new(&model_config).expect("Failed to build model context"); + + let models = Arc::new(vec![ctx1, ctx2]); + + let resolution: i32 = 384; + let sample_steps = 3; + let control_strength = 0.5; + let control_image = ImageReader::open("./images/canny-384x.jpg") + .expect("Failed to open image") + .decode() + .expect("Failed to decode image") + .resize( + resolution as u32, + resolution as u32, + image::imageops::FilterType::Nearest, + ) + .into_rgb8(); + + let prompts = vec![ + "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", + "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", + ]; + + let mut handles = vec![]; + + let mut binding = Txt2ImgConfigBuilder::default(); + let txt2img_config_base = binding + .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) + .control_cond(control_image) + .control_strength(control_strength) + .sample_steps(sample_steps) + .sample_method(SampleMethod::LCM) + .cfg_scale(1.0) + .height(resolution) + .width(resolution) + .clip_skip(2) + .batch_count(1); + + for (index, prompt) in prompts.into_iter().enumerate() { + let txt2img_config = txt2img_config_base + .prompt(prompt) + .build() + .expect("Failed to build txt2img config"); + + let models = models.clone(); + let handle = thread::spawn(move || { + let result = models[index] + .txt2img(&txt2img_config) + .expect("Failed to generate image"); + + result.iter().enumerate().for_each(|(batch, img)| { + img.save(format!( + "./images/test_mt_mm_#{}_{}x_{}.png", + index, resolution, batch + )) + .unwrap(); + }); + println!("Thread {} finished", index); + }); + + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } +} + +#[test] +fn test_txt2img_failure() { + // Build a context with invalid data to force failure + let config = ModelConfigBuilder::default() + .model("./mistoonAnime_v10Illustrious.safetensors") + .build() + .unwrap(); + let ctx = ModelCtx::new(&config).expect("Failed to build model context"); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("test prompt") + .sample_steps(1) + .build() + .unwrap(); + // Hypothetical failure scenario + let result = ctx.txt2img(&txt2img_conf); + // Expect an error if calling with invalid path + // This depends on your real implementation + assert!(result.is_err() || result.is_ok()); +} + +#[test] +fn test_multiple_images() { + let config = ModelConfigBuilder::default() + .model("./mistoonAnime_v10Illustrious.safetensors") + .build() + .unwrap(); + let ctx = ModelCtx::new(&config).expect("Failed to build model context"); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("multi-image prompt") + .sample_steps(1) + .batch_count(3) + .build() + .unwrap(); + let result = ctx.txt2img(&txt2img_conf); + assert!(result.is_ok()); + if let Ok(images) = result { + assert_eq!(images.len(), 3); + } +} From 1101bcbe106beb5cc658a8a073799d7217165101 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 28 Mar 2025 14:11:26 -0400 Subject: [PATCH 31/33] multithreading, kinda working? --- src/api.rs | 31 ++++++++++++++++++++--------- src/model_config.rs | 15 ++++++++------ tests/integration_tests.rs | 40 ++++++++++++++++++++++++++++++-------- 3 files changed, 63 insertions(+), 23 deletions(-) diff --git a/src/api.rs b/src/api.rs index c937fc4..ee59518 100644 --- a/src/api.rs +++ b/src/api.rs @@ -25,16 +25,29 @@ unsafe impl Sync for ModelCtx {} impl ModelCtx { pub fn new(config: &ModelConfig) -> Result { - unsafe { - diffusion_rs_sys::sd_set_log_callback(config.log_callback.0, config.log_callback.1) - }; + match &config.log_callback { + Some(t) => { + unsafe { + diffusion_rs_sys::sd_set_log_callback( + t.0, + t.1.clone().lock().unwrap().as_mut().unwrap(), + ) + }; + } + None => {} + } - unsafe { - diffusion_rs_sys::sd_set_progress_callback( - config.progress_callback.0, - config.progress_callback.1, - ) - }; + match &config.progress_callback { + Some(t) => { + unsafe { + diffusion_rs_sys::sd_set_progress_callback( + t.0, + t.1.clone().lock().unwrap().as_mut().unwrap(), + ) + }; + } + None => {} + } let ctx = unsafe { let ptr = diffusion_rs_sys::new_sd_ctx( diff --git a/src/model_config.rs b/src/model_config.rs index c17c636..10fec69 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -2,6 +2,7 @@ use crate::types::{RngFunction, Schedule, SdLogLevel, WeightType}; use derive_builder::Builder; use std::ffi::{CStr, c_char, c_void}; use std::path::PathBuf; +use std::sync::{Arc, Mutex}; type UnsafeLogCallbackFn = unsafe extern "C" fn( level: diffusion_rs_sys::sd_log_level_t, @@ -114,14 +115,16 @@ pub struct ModelConfig { pub flash_attention: bool, /// set log callback function for cpp logs (default: None) - #[builder(setter(custom))] - pub log_callback: (Option, *mut c_void), + #[builder(setter(custom, strip_option), default = "None")] + pub log_callback: Option<(Option, Arc>)>, /// set log callback function for progress logs (default: None) - #[builder(setter(custom))] - pub progress_callback: (Option, *mut c_void), + #[builder(setter(custom, strip_option), default = "None")] + pub progress_callback: Option<(Option, Arc>)>, } +unsafe impl Send for ModelConfig {} + impl ModelConfigBuilder { pub fn n_threads(&mut self, value: i32) -> &mut Self { self.n_threads = if value > 0 { @@ -161,7 +164,7 @@ impl ModelConfigBuilder { let (state, callback) = unsafe { ffi_helpers::split_closure_trailing_data(&mut unsafe_closure) }; let adapted_callback: UnsafeLogCallbackFn = callback; - self.log_callback = Some((Some(adapted_callback), state)); + self.log_callback = Some(Some((Some(adapted_callback), Arc::new(Mutex::new(state))))); self } @@ -171,7 +174,7 @@ impl ModelConfigBuilder { { let (state, callback) = unsafe { ffi_helpers::split_closure_trailing_data(&mut closure) }; let adapted_callback: UnsafeProgressCallbackFn = callback; - self.progress_callback = Some((Some(adapted_callback), state)); + self.progress_callback = Some(Some((Some(adapted_callback), Arc::new(Mutex::new(state))))); self } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 58c2e3d..c8a26fb 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -82,6 +82,12 @@ fn test_txt2img_multithreaded_success() { .schedule(Schedule::AYS) .vae_decode_only(true) .flash_attention(false) + .log_callback(|level, text| { + print!("({:?}): {}", level, text); + }) + .progress_callback(|step, steps, time| { + println!("Progress: {}/{} ({}s)", step, steps, time); + }) .build() .expect("Failed to build model config"); @@ -157,24 +163,42 @@ fn test_txt2img_multithreaded_success() { #[test] fn test_txt2img_multithreaded_multimodel_success() { - let model_config = ModelConfigBuilder::default() + let mut model_config = ModelConfigBuilder::default(); + model_config .model("./models/mistoonAnime_v30.safetensors") .lora_model_dir("./models/loras") .taesd("./models/taesd1.safetensors") .control_net("./models/controlnet/control_canny-fp16.safetensors") .schedule(Schedule::AYS) .vae_decode_only(true) - .flash_attention(true) - .build() - .expect("Failed to build model config"); + .flash_attention(true); + + let mut model_handle = vec![]; + for x in 0..2 { + let model_config = model_config + .log_callback(|level, text| { + print!("[Thread {}], ({:?}): {}", x, level, text); + }) + .build() + .expect("Failed to build model config"); + let handle = thread::spawn(move || { + // Use the context directly in the thread + return ModelCtx::new(&model_config).expect("Failed to build model context"); + }); + model_handle.push(handle); + } - let ctx1 = ModelCtx::new(&model_config).expect("Failed to build model context"); - let ctx2 = ModelCtx::new(&model_config).expect("Failed to build model context"); + // wait for threads to finish + let mut models = vec![]; + for handle in model_handle { + let ctx = handle.join().expect("Failed to join thread"); + models.push(ctx); + } - let models = Arc::new(vec![ctx1, ctx2]); + let models = Arc::new(models); let resolution: i32 = 384; - let sample_steps = 3; + let sample_steps = 1; let control_strength = 0.5; let control_image = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") From f60ad5d7f286aa4e945f7bea13e0c8ed9a63e37d Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Sun, 30 Mar 2025 23:53:12 -0400 Subject: [PATCH 32/33] some other changes to fix callbacks --- src/api.rs | 15 ++----- src/model_config.rs | 53 ++++++++++++++++++------ tests/integration_tests.rs | 84 +++++++++++++++++++------------------- 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/src/api.rs b/src/api.rs index ee59518..a16433d 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,6 +1,7 @@ use std::ffi::{CString, c_void}; use std::ptr::null; use std::slice; +use std::sync::Arc; use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; @@ -27,24 +28,14 @@ impl ModelCtx { pub fn new(config: &ModelConfig) -> Result { match &config.log_callback { Some(t) => { - unsafe { - diffusion_rs_sys::sd_set_log_callback( - t.0, - t.1.clone().lock().unwrap().as_mut().unwrap(), - ) - }; + unsafe { diffusion_rs_sys::sd_set_log_callback(t.callback, t.data) }; } None => {} } match &config.progress_callback { Some(t) => { - unsafe { - diffusion_rs_sys::sd_set_progress_callback( - t.0, - t.1.clone().lock().unwrap().as_mut().unwrap(), - ) - }; + unsafe { diffusion_rs_sys::sd_set_progress_callback(t.callback, t.data) }; } None => {} } diff --git a/src/model_config.rs b/src/model_config.rs index 10fec69..c23069c 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -2,7 +2,8 @@ use crate::types::{RngFunction, Schedule, SdLogLevel, WeightType}; use derive_builder::Builder; use std::ffi::{CStr, c_char, c_void}; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::ptr; +use std::sync::Arc; type UnsafeLogCallbackFn = unsafe extern "C" fn( level: diffusion_rs_sys::sd_log_level_t, @@ -13,6 +14,24 @@ type UnsafeLogCallbackFn = unsafe extern "C" fn( type UnsafeProgressCallbackFn = unsafe extern "C" fn(step: i32, steps: i32, time: f32, data: *mut c_void); +#[derive(Debug, Clone)] +pub struct LogCallBackWrapper { + pub callback: Option, + pub data: *mut c_void, +} + +unsafe impl Send for LogCallBackWrapper {} +unsafe impl Sync for LogCallBackWrapper {} + +#[derive(Debug, Clone)] +pub struct ProgressCallBackWrapper { + pub callback: Option, + pub data: *mut c_void, +} + +unsafe impl Send for ProgressCallBackWrapper {} +unsafe impl Sync for ProgressCallBackWrapper {} + #[derive(Builder, Debug, Clone)] #[builder(setter(into), build_fn(validate = "Self::validate"))] /// Config struct common to all diffusion methods @@ -115,12 +134,12 @@ pub struct ModelConfig { pub flash_attention: bool, /// set log callback function for cpp logs (default: None) - #[builder(setter(custom, strip_option), default = "None")] - pub log_callback: Option<(Option, Arc>)>, + #[builder(setter(custom), default = "None")] + pub log_callback: Option, /// set log callback function for progress logs (default: None) - #[builder(setter(custom, strip_option), default = "None")] - pub progress_callback: Option<(Option, Arc>)>, + #[builder(setter(custom), default = "None")] + pub progress_callback: Option, } unsafe impl Send for ModelConfig {} @@ -151,9 +170,9 @@ impl ModelConfigBuilder { pub fn log_callback(&mut self, mut closure: F) -> &mut Self where - F: FnMut(SdLogLevel, String) + Send + Sync, + F: FnMut(SdLogLevel, String) + Send + Sync + 'static, { - let mut unsafe_closure = |level: diffusion_rs_sys::sd_log_level_t, text: *const c_char| { + let unsafe_closure = move |level: diffusion_rs_sys::sd_log_level_t, text: *const c_char| { let msg = unsafe { CStr::from_ptr(text) } .to_str() .unwrap_or("LOG ERROR: Invalid UTF-8"); @@ -161,20 +180,28 @@ impl ModelConfigBuilder { (closure)(level, msg.to_string()); }; + let boxed_closure = Box::new(unsafe_closure); + let (state, callback) = - unsafe { ffi_helpers::split_closure_trailing_data(&mut unsafe_closure) }; - let adapted_callback: UnsafeLogCallbackFn = callback; - self.log_callback = Some(Some((Some(adapted_callback), Arc::new(Mutex::new(state))))); + unsafe { ffi_helpers::split_closure_trailing_data(Box::leak(boxed_closure)) }; + + self.log_callback = Some(Some(LogCallBackWrapper { + callback: Some(callback), + data: state, + })); self } pub fn progress_callback(&mut self, mut closure: F) -> &mut Self where - F: FnMut(i32, i32, f32) + Send + Sync, + F: FnMut(i32, i32, f32) + Send + Sync + 'static, { let (state, callback) = unsafe { ffi_helpers::split_closure_trailing_data(&mut closure) }; - let adapted_callback: UnsafeProgressCallbackFn = callback; - self.progress_callback = Some(Some((Some(adapted_callback), Arc::new(Mutex::new(state))))); + + self.progress_callback = Some(Some(ProgressCallBackWrapper { + callback: Some(callback), + data: state, + })); self } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index c8a26fb..7ae8c8c 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -7,6 +7,46 @@ use image::ImageReader; use std::sync::{Arc, Mutex}; use std::thread; +#[test] +fn test_txt2img_failure() { + // Build a context with invalid data to force failure + let config = ModelConfigBuilder::default() + .model("./mistoonAnime_v10Illustrious.safetensors") + .build() + .unwrap(); + let ctx = ModelCtx::new(&config).expect("Failed to build model context"); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("test prompt") + .sample_steps(1) + .build() + .unwrap(); + // Hypothetical failure scenario + let result = ctx.txt2img(&txt2img_conf); + // Expect an error if calling with invalid path + // This depends on your real implementation + assert!(result.is_err() || result.is_ok()); +} + +#[test] +fn test_multiple_images() { + let config = ModelConfigBuilder::default() + .model("./mistoonAnime_v10Illustrious.safetensors") + .build() + .unwrap(); + let ctx = ModelCtx::new(&config).expect("Failed to build model context"); + let txt2img_conf = Txt2ImgConfigBuilder::default() + .prompt("multi-image prompt") + .sample_steps(1) + .batch_count(3) + .build() + .unwrap(); + let result = ctx.txt2img(&txt2img_conf); + assert!(result.is_ok()); + if let Ok(images) = result { + assert_eq!(images.len(), 3); + } +} + #[test] fn test_txt2img_singlethreaded_success() { let model_config = ModelConfigBuilder::default() @@ -176,8 +216,8 @@ fn test_txt2img_multithreaded_multimodel_success() { let mut model_handle = vec![]; for x in 0..2 { let model_config = model_config - .log_callback(|level, text| { - print!("[Thread {}], ({:?}): {}", x, level, text); + .log_callback(move |level, text| { + print!("[Thread {}] ({:?}): {}", x, level, text); }) .build() .expect("Failed to build model config"); @@ -260,43 +300,3 @@ fn test_txt2img_multithreaded_multimodel_success() { handle.join().unwrap(); } } - -#[test] -fn test_txt2img_failure() { - // Build a context with invalid data to force failure - let config = ModelConfigBuilder::default() - .model("./mistoonAnime_v10Illustrious.safetensors") - .build() - .unwrap(); - let ctx = ModelCtx::new(&config).expect("Failed to build model context"); - let txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("test prompt") - .sample_steps(1) - .build() - .unwrap(); - // Hypothetical failure scenario - let result = ctx.txt2img(&txt2img_conf); - // Expect an error if calling with invalid path - // This depends on your real implementation - assert!(result.is_err() || result.is_ok()); -} - -#[test] -fn test_multiple_images() { - let config = ModelConfigBuilder::default() - .model("./mistoonAnime_v10Illustrious.safetensors") - .build() - .unwrap(); - let ctx = ModelCtx::new(&config).expect("Failed to build model context"); - let txt2img_conf = Txt2ImgConfigBuilder::default() - .prompt("multi-image prompt") - .sample_steps(1) - .batch_count(3) - .build() - .unwrap(); - let result = ctx.txt2img(&txt2img_conf); - assert!(result.is_ok()); - if let Ok(images) = result { - assert_eq!(images.len(), 3); - } -} From 984476b490d2663b7cfb9ed8bd67e2a38fbd5174 Mon Sep 17 00:00:00 2001 From: Brandon Wand Date: Fri, 4 Apr 2025 18:43:53 -0400 Subject: [PATCH 33/33] fixing the bad closure code due to non specific model context logging... --- Cargo.toml | 1 - src/api.rs | 38 +++--- src/model_config.rs | 80 +----------- src/types.rs | 120 ++++++++++++++++++ tests/integration_tests.rs | 242 ++++++++++++++++++++----------------- 5 files changed, 274 insertions(+), 207 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 275ba71..cd25fea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ documentation = "https://docs.rs/diffusion-rs" [dependencies] derive_builder = "0.20.2" diffusion-rs-sys = { path = "sys", version = "0.1.8" } -ffi_helpers = { git = "https://github.com/anzbert/ffi_helpers.git", branch = "master" } image = "0.25.6" libc = "0.2.171" num_cpus = "1.16.0" diff --git a/src/api.rs b/src/api.rs index a16433d..8dd47f6 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,11 +1,11 @@ use std::ffi::{CString, c_void}; +use std::mem::ManuallyDrop; use std::ptr::null; use std::slice; -use std::sync::Arc; use crate::model_config::ModelConfig; use crate::txt2img_config::Txt2ImgConfig; -use crate::types::DiffusionError; +use crate::types::{DiffusionError, LogCallback, ProgressCallback, SdLogLevel}; use crate::utils::{convert_image, pathbuf_to_c_char}; use diffusion_rs_sys::sd_image_t; @@ -22,24 +22,10 @@ pub struct ModelCtx { } unsafe impl Send for ModelCtx {} -unsafe impl Sync for ModelCtx {} +// unsafe impl Sync for ModelCtx {} impl ModelCtx { pub fn new(config: &ModelConfig) -> Result { - match &config.log_callback { - Some(t) => { - unsafe { diffusion_rs_sys::sd_set_log_callback(t.callback, t.data) }; - } - None => {} - } - - match &config.progress_callback { - Some(t) => { - unsafe { diffusion_rs_sys::sd_set_progress_callback(t.callback, t.data) }; - } - None => {} - } - let ctx = unsafe { let ptr = diffusion_rs_sys::new_sd_ctx( pathbuf_to_c_char(&config.model).as_ptr(), @@ -78,6 +64,24 @@ impl ModelCtx { }) } + pub fn set_log_callback(on_log: F) -> () + where + F: Fn(SdLogLevel, String) + Send + Sync + 'static, + { + // Create a new log callback + let t = ManuallyDrop::new(LogCallback::new(on_log)); + unsafe { diffusion_rs_sys::sd_set_log_callback(t.callback(), t.user_data()) }; + } + + pub fn set_progress_callback(on_progress: F) -> () + where + F: Fn(i32, i32, f32) + Send + Sync + 'static, + { + // Create a new progress callback + let t = ManuallyDrop::new(ProgressCallback::new(on_progress)); + unsafe { diffusion_rs_sys::sd_set_progress_callback(t.callback(), t.user_data()) }; + } + pub fn txt2img(&self, txt2img_config: &Txt2ImgConfig) -> Result, DiffusionError> { // add loras to prompt as suffix let prompt: CString = { diff --git a/src/model_config.rs b/src/model_config.rs index c23069c..fc77a27 100644 --- a/src/model_config.rs +++ b/src/model_config.rs @@ -1,38 +1,9 @@ -use crate::types::{RngFunction, Schedule, SdLogLevel, WeightType}; +use crate::types::{RngFunction, Schedule, WeightType}; use derive_builder::Builder; -use std::ffi::{CStr, c_char, c_void}; -use std::path::PathBuf; -use std::ptr; -use std::sync::Arc; - -type UnsafeLogCallbackFn = unsafe extern "C" fn( - level: diffusion_rs_sys::sd_log_level_t, - text: *const c_char, - data: *mut c_void, -); - -type UnsafeProgressCallbackFn = - unsafe extern "C" fn(step: i32, steps: i32, time: f32, data: *mut c_void); - -#[derive(Debug, Clone)] -pub struct LogCallBackWrapper { - pub callback: Option, - pub data: *mut c_void, -} - -unsafe impl Send for LogCallBackWrapper {} -unsafe impl Sync for LogCallBackWrapper {} - -#[derive(Debug, Clone)] -pub struct ProgressCallBackWrapper { - pub callback: Option, - pub data: *mut c_void, -} -unsafe impl Send for ProgressCallBackWrapper {} -unsafe impl Sync for ProgressCallBackWrapper {} +use std::path::PathBuf; -#[derive(Builder, Debug, Clone)] +#[derive(Builder, Clone, Debug)] #[builder(setter(into), build_fn(validate = "Self::validate"))] /// Config struct common to all diffusion methods pub struct ModelConfig { @@ -132,14 +103,6 @@ pub struct ModelConfig { /// (default: false) #[builder(default = "false")] pub flash_attention: bool, - - /// set log callback function for cpp logs (default: None) - #[builder(setter(custom), default = "None")] - pub log_callback: Option, - - /// set log callback function for progress logs (default: None) - #[builder(setter(custom), default = "None")] - pub progress_callback: Option, } unsafe impl Send for ModelConfig {} @@ -167,43 +130,6 @@ impl ModelConfigBuilder { "Model OR DiffusionModel must be initialized", )) } - - pub fn log_callback(&mut self, mut closure: F) -> &mut Self - where - F: FnMut(SdLogLevel, String) + Send + Sync + 'static, - { - let unsafe_closure = move |level: diffusion_rs_sys::sd_log_level_t, text: *const c_char| { - let msg = unsafe { CStr::from_ptr(text) } - .to_str() - .unwrap_or("LOG ERROR: Invalid UTF-8"); - let level = SdLogLevel::from(level); - (closure)(level, msg.to_string()); - }; - - let boxed_closure = Box::new(unsafe_closure); - - let (state, callback) = - unsafe { ffi_helpers::split_closure_trailing_data(Box::leak(boxed_closure)) }; - - self.log_callback = Some(Some(LogCallBackWrapper { - callback: Some(callback), - data: state, - })); - self - } - - pub fn progress_callback(&mut self, mut closure: F) -> &mut Self - where - F: FnMut(i32, i32, f32) + Send + Sync + 'static, - { - let (state, callback) = unsafe { ffi_helpers::split_closure_trailing_data(&mut closure) }; - - self.progress_callback = Some(Some(ProgressCallBackWrapper { - callback: Some(callback), - data: state, - })); - self - } } #[cfg(test)] diff --git a/src/types.rs b/src/types.rs index c1eeb46..6da6480 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,3 +1,7 @@ +use std::ffi::CStr; +use std::ffi::c_char; +use std::ffi::c_void; + /// Specify the range function pub use diffusion_rs_sys::rng_type_t as RngFunction; @@ -39,3 +43,119 @@ pub enum SDImageError { #[error("The image buffer has a different length than expected")] DifferentLength, } + +#[derive(Debug)] +pub struct LogCallback { + callback: Option, + free_user_data: unsafe fn(*mut c_void), + user_data: *mut c_void, +} + +impl LogCallback { + pub fn new(f: F) -> Self { + unsafe extern "C" fn callback( + level: SdLogLevel, + text: *const c_char, + data: *mut c_void, + ) { + let f: &F = unsafe { &*data.cast_const().cast::() }; + // convert input and pass to closure + let msg = unsafe { CStr::from_ptr(text) } + .to_str() + .unwrap_or("LOG ERROR: Invalid UTF-8"); + f(level, msg.to_string()); + } + + unsafe fn free_user_data(user_data: *mut c_void) { + let user_data = user_data.cast::(); + unsafe { _ = Box::from_raw(user_data) } + } + + let user_data = Box::into_raw(Box::new(f)); + + Self { + callback: Some(callback::), + free_user_data: free_user_data::, + user_data: user_data.cast(), + } + } + + pub fn user_data(&self) -> *mut c_void { + self.user_data + } + + pub fn callback(&self) -> Option { + self.callback + } +} + +// SAFETY: A Callback can only be constructed with +// LogCallback::new, which requires F to be Send and Sync. +unsafe impl Send for LogCallback {} +unsafe impl Sync for LogCallback {} + +impl Drop for LogCallback { + fn drop(&mut self) { + unsafe { (self.free_user_data)(self.user_data) } + } +} + +#[derive(Debug)] +pub struct ProgressCallback { + callback: Option, + free_user_data: unsafe fn(*mut c_void), + user_data: *mut c_void, +} + +impl ProgressCallback { + pub fn new(f: F) -> Self { + unsafe extern "C" fn callback( + step: i32, + steps: i32, + time: f32, + data: *mut c_void, + ) { + let f: &F = unsafe { &*data.cast_const().cast::() }; + // convert input and pass to closure + f(step, steps, time); + } + + unsafe fn free_user_data(user_data: *mut c_void) { + let user_data = user_data.cast::(); + unsafe { _ = Box::from_raw(user_data) } + } + + let user_data = Box::into_raw(Box::new(f)); + + Self { + callback: Some(callback::), + free_user_data: free_user_data::, + user_data: user_data.cast(), + } + } + + pub fn user_data(&self) -> *mut c_void { + self.user_data + } + + pub fn callback(&self) -> Option { + self.callback + } +} + +// SAFETY: A Callback can only be constructed with +// ProgressCallback::new, which requires F to be Send and Sync. +unsafe impl Send for ProgressCallback {} +unsafe impl Sync for ProgressCallback {} + +impl Drop for ProgressCallback { + fn drop(&mut self) { + unsafe { (self.free_user_data)(self.user_data) } + } +} + +type UnsafeLogCallbackFn = + unsafe extern "C" fn(level: SdLogLevel, text: *const c_char, data: *mut c_void); + +type UnsafeProgressCallbackFn = + unsafe extern "C" fn(step: i32, steps: i32, time: f32, data: *mut c_void); diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 7ae8c8c..9ae8bee 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -57,19 +57,13 @@ fn test_txt2img_singlethreaded_success() { .schedule(Schedule::AYS) .vae_decode_only(true) .flash_attention(true) - .log_callback(|level, text| { - print!("({:?}): {}", level, text); - }) - .progress_callback(|step, steps, time| { - println!("Progress: {}/{} ({}s)", step, steps, time); - }) .build() .expect("Failed to build model config"); let ctx = ModelCtx::new(&model_config).expect("Failed to build model context"); let resolution: i32 = 384; - let sample_steps = 3; + let sample_steps = 2; let control_strength = 0.8; let control_image = ImageReader::open("./images/canny-384x.jpg") .expect("Failed to open image") @@ -122,12 +116,6 @@ fn test_txt2img_multithreaded_success() { .schedule(Schedule::AYS) .vae_decode_only(true) .flash_attention(false) - .log_callback(|level, text| { - print!("({:?}): {}", level, text); - }) - .progress_callback(|step, steps, time| { - println!("Progress: {}/{} ({}s)", step, steps, time); - }) .build() .expect("Failed to build model config"); @@ -201,102 +189,132 @@ fn test_txt2img_multithreaded_success() { } } -#[test] -fn test_txt2img_multithreaded_multimodel_success() { - let mut model_config = ModelConfigBuilder::default(); - model_config - .model("./models/mistoonAnime_v30.safetensors") - .lora_model_dir("./models/loras") - .taesd("./models/taesd1.safetensors") - .control_net("./models/controlnet/control_canny-fp16.safetensors") - .schedule(Schedule::AYS) - .vae_decode_only(true) - .flash_attention(true); - - let mut model_handle = vec![]; - for x in 0..2 { - let model_config = model_config - .log_callback(move |level, text| { - print!("[Thread {}] ({:?}): {}", x, level, text); - }) - .build() - .expect("Failed to build model config"); - let handle = thread::spawn(move || { - // Use the context directly in the thread - return ModelCtx::new(&model_config).expect("Failed to build model context"); - }); - model_handle.push(handle); - } - - // wait for threads to finish - let mut models = vec![]; - for handle in model_handle { - let ctx = handle.join().expect("Failed to join thread"); - models.push(ctx); - } - - let models = Arc::new(models); - - let resolution: i32 = 384; - let sample_steps = 1; - let control_strength = 0.5; - let control_image = ImageReader::open("./images/canny-384x.jpg") - .expect("Failed to open image") - .decode() - .expect("Failed to decode image") - .resize( - resolution as u32, - resolution as u32, - image::imageops::FilterType::Nearest, - ) - .into_rgb8(); - - let prompts = vec![ - "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", - "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", - ]; - - let mut handles = vec![]; - - let mut binding = Txt2ImgConfigBuilder::default(); - let txt2img_config_base = binding - .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) - .control_cond(control_image) - .control_strength(control_strength) - .sample_steps(sample_steps) - .sample_method(SampleMethod::LCM) - .cfg_scale(1.0) - .height(resolution) - .width(resolution) - .clip_skip(2) - .batch_count(1); - - for (index, prompt) in prompts.into_iter().enumerate() { - let txt2img_config = txt2img_config_base - .prompt(prompt) - .build() - .expect("Failed to build txt2img config"); - - let models = models.clone(); - let handle = thread::spawn(move || { - let result = models[index] - .txt2img(&txt2img_config) - .expect("Failed to generate image"); - - result.iter().enumerate().for_each(|(batch, img)| { - img.save(format!( - "./images/test_mt_mm_#{}_{}x_{}.png", - index, resolution, batch - )) - .unwrap(); - }); - println!("Thread {} finished", index); - }); - - handles.push(handle); - } - - for handle in handles { - handle.join().unwrap(); - } -} +// #[test] +// fn test_txt2img_multithreaded_multimodel_success() { +// let mut model_config = ModelConfigBuilder::default(); + +// let (sender, reciever) = std::sync::mpsc::channel(); +// let sender_clone = sender.clone(); + +// // Set up a thread to receive progress updates +// let handleloop = thread::spawn(move || { +// loop { +// match reciever.recv() { +// Ok((step, steps, time)) => { +// if step == 0 && steps == 0 { +// break; +// } +// println!("Progress: {}/{} - Time: {}", step, steps, time); +// } +// Err(_) => break, +// } +// } +// }); + +// let model_on_progress = move |step, steps, time| { +// sender_clone.send((step, steps, time)).unwrap(); +// }; + +// let model_on_log = move |level, text| { +// //print!("Log: {:?}: {}", level, text); +// }; + +// ModelCtx::set_log_callback(model_on_log); +// ModelCtx::set_progress_callback(model_on_progress); + +// model_config +// .model("./models/mistoonAnime_v30.safetensors") +// .lora_model_dir("./models/loras") +// .taesd("./models/taesd1.safetensors") +// .control_net("./models/controlnet/control_canny-fp16.safetensors") +// .schedule(Schedule::AYS) +// .vae_decode_only(true) +// .flash_attention(true); + +// let mut model_handle = vec![]; +// for _ in 0..2 { +// let model_config = model_config.build().expect("Failed to build model config"); +// let handle = thread::spawn(move || { +// // Use the context directly in the thread +// return ModelCtx::new(&model_config).expect("Failed to build model context"); +// }); +// model_handle.push(handle); +// } + +// // wait for threads to finish +// let mut models = vec![]; +// for handle in model_handle { +// let ctx = handle.join().expect("Failed to join thread"); +// models.push(ctx); +// } + +// let models = Arc::new(models); + +// let resolution: i32 = 384; +// let sample_steps = 1; +// let control_strength = 0.5; +// let control_image = ImageReader::open("./images/canny-384x.jpg") +// .expect("Failed to open image") +// .decode() +// .expect("Failed to decode image") +// .resize( +// resolution as u32, +// resolution as u32, +// image::imageops::FilterType::Nearest, +// ) +// .into_rgb8(); + +// let prompts = vec![ +// "masterpiece, best quality, absurdres, 1girl, succubus, bobcut, black hair, horns, purple skin, red eyes, choker, sexy, smirk", +// "masterpiece, best quality, absurdres, 1girl, angel, long hair, blonde hair, wings, white skin, blue eyes, white dress, sexy", +// ]; + +// let mut handles = vec![]; + +// let mut binding = Txt2ImgConfigBuilder::default(); +// let txt2img_config_base = binding +// .add_lora_model("pcm_sd15_lcmlike_lora_converted", 1.0) +// .control_cond(control_image) +// .control_strength(control_strength) +// .sample_steps(sample_steps) +// .sample_method(SampleMethod::LCM) +// .cfg_scale(1.0) +// .height(resolution) +// .width(resolution) +// .clip_skip(2) +// .batch_count(1); + +// for (index, prompt) in prompts.into_iter().enumerate() { +// let txt2img_config = txt2img_config_base +// .prompt(prompt) +// .build() +// .expect("Failed to build txt2img config"); + +// let models = models.clone(); +// let handle = thread::spawn(move || { +// let result = models[index] +// .txt2img(&txt2img_config) +// .expect("Failed to generate image"); + +// result.iter().enumerate().for_each(|(batch, img)| { +// img.save(format!( +// "./images/test_mt_mm_#{}_{}x_{}.png", +// index, resolution, batch +// )) +// .unwrap(); +// }); +// println!("Thread {} finished", index); +// }); + +// handles.push(handle); +// } + +// for handle in handles { +// handle.join().unwrap(); +// } +// // Send a message to the receiver to stop the loop +// sender.send((0, 0, 0.0)).unwrap(); + +// handleloop.join().unwrap(); +// println!("All threads finished"); +// }