Skip to content

Commit 2995c92

Browse files
authored
Tiled vae parameter validation (#6)
* avoid crash with invalid tile sizes, use 0 for default
* refactor default tile size, limit overlap factor
* remove explicit parameter for relative tile size
* limit encoding tile to latent size
1 parent 987ced8 commit 2995c92

File tree

3 files changed

+34
-43
lines changed

3 files changed

+34
-43
lines changed

examples/cli/main.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ struct SDParams {
118118
int chroma_t5_mask_pad = 1;
119119
float flow_shift = INFINITY;
120120

121-
sd_tiling_params_t vae_tiling_params = {false, 32, 32, 0.5f, false, 0.0f, 0.0f};
121+
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
122122

123123
SDParams() {
124124
sd_sample_params_init(&sample_params);
@@ -749,7 +749,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
749749
} catch (const std::out_of_range& e) {
750750
return -1;
751751
}
752-
params.vae_tiling_params.relative = false;
753752
return 1;
754753
};
755754

@@ -773,7 +772,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
773772
} catch (const std::out_of_range& e) {
774773
return -1;
775774
}
776-
params.vae_tiling_params.relative = true;
777775
return 1;
778776
};
779777

stable-diffusion.cpp

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class StableDiffusionGGML {
108108

109109
std::string taesd_path;
110110
bool use_tiny_autoencoder = false;
111-
sd_tiling_params_t vae_tiling_params = {false, 32, 32, 0.5f, false, 0, 0};
111+
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0};
112112
bool offload_params_to_cpu = false;
113113
bool stacked_id = false;
114114

@@ -1300,27 +1300,30 @@ class StableDiffusionGGML {
13001300
return latent;
13011301
}
13021302

1303-
void get_relative_tile_sizes(int& tile_size_x, int& tile_size_y, float tile_overlap, float rel_size_x, float rel_size_y, int latent_x, int latent_y) {
1304-
// format is AxB, or just A (equivalent to AxA)
1305-
// A and B can be integers (tile size) or floating point
1306-
// floating point <= 1 means simple fraction of the latent dimension
1307-
// floating point > 1 means number of tiles across that dimension
1308-
// a single number gets applied to both
1309-
auto get_tile_factor = [tile_overlap](float factor) {
1310-
if (factor > 1.0)
1311-
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1312-
return factor;
1313-
};
1314-
const int min_tile_dimension = 4;
1315-
1316-
int tmp_x = tile_size_x, tmp_y = tile_size_y;
1317-
tmp_x = std::round(latent_x * get_tile_factor(rel_size_x));
1318-
tmp_y = std::round(latent_y * get_tile_factor(rel_size_y));
13191303

1320-
tile_size_x = std::max(std::min(tmp_x, latent_x), min_tile_dimension);
1321-
tile_size_y = std::max(std::min(tmp_y, latent_y), min_tile_dimension);
1304+
void get_tile_sizes(int& tile_size_x, int& tile_size_y, float& tile_overlap, const sd_tiling_params_t & params,
1305+
int latent_x, int latent_y, float encoding_factor = 1.0f) {
1306+
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
1307+
auto get_tile_size = [&](int requested_size, float factor, int latent_size) {
1308+
const int default_tile_size = 32;
1309+
const int min_tile_dimension = 4;
1310+
int tile_size = default_tile_size;
1311+
// rel_size <= 1 means simple fraction of the latent dimension
1312+
// rel_size > 1 means number of tiles across that dimension
1313+
if (factor > 0.f) {
1314+
if (factor > 1.0)
1315+
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
1316+
tile_size = std::round(latent_size * factor);
1317+
}
1318+
else if (requested_size >= min_tile_dimension) {
1319+
tile_size = requested_size;
1320+
}
1321+
tile_size *= encoding_factor;
1322+
return std::max(std::min(tile_size, latent_size), min_tile_dimension);
1323+
};
13221324

1323-
LOG_INFO("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
1325+
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
1326+
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
13241327
}
13251328

13261329
ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
@@ -1336,20 +1339,14 @@ class StableDiffusionGGML {
13361339
}
13371340
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
13381341
}
1339-
// TODO: args instead of env for tile size / overlap?
1340-
if (!use_tiny_autoencoder) {
1341-
float tile_overlap = vae_tiling_params.target_overlap;
1342-
int tile_size_x = vae_tiling_params.tile_size_x;
1343-
int tile_size_y = vae_tiling_params.tile_size_y;
13441342

1345-
if (vae_tiling_params.relative) {
1346-
get_relative_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params.rel_size_x, vae_tiling_params.rel_size_y, W, H);
1347-
}
1348-
1349-
// TODO: also use an arg for this one?
1343+
if (!use_tiny_autoencoder) {
1344+
float tile_overlap;
1345+
int tile_size_x, tile_size_y;
13501346
// multiply tile size for encode to keep the compute buffer size consistent
1351-
tile_size_x *= 1.30539;
1352-
tile_size_y *= 1.30539;
1347+
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539);
1348+
1349+
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
13531350

13541351
process_vae_input_tensor(x);
13551352
if (vae_tiling_params.enabled && !decode_video) {
@@ -1489,13 +1486,10 @@ class StableDiffusionGGML {
14891486
}
14901487
int64_t t0 = ggml_time_ms();
14911488
if (!use_tiny_autoencoder) {
1492-
float tile_overlap = vae_tiling_params.target_overlap;
1493-
int tile_size_x = vae_tiling_params.tile_size_x;
1494-
int tile_size_y = vae_tiling_params.tile_size_y;
1489+
float tile_overlap;
1490+
int tile_size_x, tile_size_y;
1491+
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H);
14951492

1496-
if (vae_tiling_params.relative) {
1497-
get_relative_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params.rel_size_x, vae_tiling_params.rel_size_y, x->ne[0], x->ne[1]);
1498-
}
14991493
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
15001494

15011495
process_latent_out(x);
@@ -1769,7 +1763,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
17691763
sd_img_gen_params->control_strength = 0.9f;
17701764
sd_img_gen_params->style_strength = 20.f;
17711765
sd_img_gen_params->normalize_input = false;
1772-
sd_img_gen_params->vae_tiling_params = {false, 32, 32, 0.5f, false, 0.0f, 0.0f};
1766+
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
17731767
}
17741768

17751769
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {

stable-diffusion.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ typedef struct {
118118
int tile_size_x;
119119
int tile_size_y;
120120
float target_overlap;
121-
bool relative;
122121
float rel_size_x;
123122
float rel_size_y;
124123
} sd_tiling_params_t;

0 commit comments

Comments
 (0)