diff --git a/backend/backend.proto b/backend/backend.proto index 2ea7f6f1048a..3692180100f4 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -301,7 +301,6 @@ message TranscriptSegment { message GenerateImageRequest { int32 height = 1; int32 width = 2; - int32 mode = 3; int32 step = 4; int32 seed = 5; string positive_prompt = 6; diff --git a/core/backend/image.go b/core/backend/image.go index b6bb4f8a74e0..651293cf5e1b 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -7,7 +7,7 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { +func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { opts := ModelOptions(modelConfig, appConfig) inferenceModel, err := loader.Load( @@ -23,7 +23,6 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat &proto.GenerateImageRequest{ Height: int32(height), Width: int32(width), - Mode: int32(mode), Step: int32(step), Seed: int32(seed), CLIPSkip: int32(modelConfig.Diffusers.ClipSkip), diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 025abaa9418f..3575fee2b167 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -157,16 +157,11 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi negative_prompt = prompts[1] } - mode := 0 step := config.Step if step == 0 { step = 15 } - if input.Mode != 0 { - mode = input.Mode - } - if input.Step != 0 { step = input.Step } @@ -197,7 +192,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi inputSrc = inputImages[0] } - fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages) + fn, err := backend.ImageGeneration(height, width, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages) if err != nil { return err } diff --git a/core/http/endpoints/openai/inpainting.go b/core/http/endpoints/openai/inpainting.go index 90c5a05d18cf..a27ffea54dc9 100644 --- a/core/http/endpoints/openai/inpainting.go +++ b/core/http/endpoints/openai/inpainting.go @@ -231,7 +231,7 @@ func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app // Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON // Also pass ref images (orig + mask) so backends that support ref images can use them. refImages := []string{origRef, maskRef} - fn, err := backend.ImageGenerationFunc(height, width, 0, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages) + fn, err := backend.ImageGenerationFunc(height, width, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages) if err != nil { return err } diff --git a/core/http/endpoints/openai/inpainting_test.go b/core/http/endpoints/openai/inpainting_test.go index 6b57e4a3dcdf..de4678d347e8 100644 --- a/core/http/endpoints/openai/inpainting_test.go +++ b/core/http/endpoints/openai/inpainting_test.go @@ -10,9 +10,9 @@ import ( "testing" "github.com/labstack/echo/v4" - "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" model "github.com/mudler/LocalAI/pkg/model" "github.com/stretchr/testify/require" ) @@ -58,7 +58,7 @@ func TestInpainting_HappyPath(t *testing.T) { // stub the backend.ImageGenerationFunc orig := backend.ImageGenerationFunc - backend.ImageGenerationFunc = func(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { + backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { fn := func() error { // write a fake png file to dst return os.WriteFile(dst, []byte("PNGDATA"), 0644) diff --git a/core/http/static/image.js b/core/http/static/image.js index 24a06557875b..16b18735fd57 100644 --- a/core/http/static/image.js +++ b/core/http/static/image.js @@ -1,61 +1,255 @@ +// Helper function to convert file to base64 +function fileToBase64(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => { + // Remove data:image/...;base64, prefix if present + const base64 = reader.result.split(',')[1] || reader.result; + resolve(base64); + }; + reader.onerror = reject; + reader.readAsDataURL(file); + }); +} + +// Helper function to read multiple files +async function filesToBase64Array(fileList) { + const base64Array = []; + for (let i = 0; i < fileList.length; i++) { + const base64 = await fileToBase64(fileList[i]); + base64Array.push(base64); + } + return base64Array; +} + function genImage(event) { event.preventDefault(); - const input = document.getElementById("input").value; - - promptDallE(input); + promptDallE(); } - -async function promptDallE(input) { - document.getElementById("loader").style.display = "block"; - document.getElementById("input").value = ""; - document.getElementById("input").disabled = true; +async function promptDallE() { + const loader = document.getElementById("loader"); + const input = document.getElementById("input"); + const generateBtn = document.getElementById("generate-btn"); + const resultDiv = document.getElementById("result"); + + // Show loader and disable form + loader.style.display = "block"; + input.disabled = true; + generateBtn.disabled = true; + + // Store the prompt for later restoration + const prompt = input.value.trim(); + if (!prompt) { + alert("Please enter a prompt"); + loader.style.display = "none"; + input.disabled = false; + generateBtn.disabled = false; + return; + } + + // Collect all form values const model = document.getElementById("image-model").value; const size = document.getElementById("image-size").value; - const response = await fetch("v1/images/generations", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - model: model, - steps: 10, - prompt: input, - n: 1, - size: size, - }), - }); - const json = await response.json(); - if (json.error) { - // Display error if there is one - var div = document.getElementById('result'); // Get the div by its ID - div.innerHTML = '
' + json.error.message + '
'; + const negativePrompt = document.getElementById("negative-prompt").value.trim(); + const n = parseInt(document.getElementById("image-count").value) || 1; + const stepInput = document.getElementById("image-steps").value.trim(); + const step = stepInput ? parseInt(stepInput) : undefined; + const seedInput = document.getElementById("image-seed").value.trim(); + const seed = seedInput ? parseInt(seedInput) : undefined; + + // Prepare request body + // Combine prompt and negative prompt with "|" separator (backend expects this format) + let combinedPrompt = prompt; + if (negativePrompt) { + combinedPrompt = prompt + "|" + negativePrompt; + } + + const requestBody = { + model: model, + prompt: combinedPrompt, + n: n, + size: size, + }; + + if (step !== undefined) { + requestBody.step = step; + } + + if (seed !== undefined) { + requestBody.seed = seed; + } + + // Handle file inputs + try { + // Source image (single file for img2img) + const sourceImageInput = document.getElementById("source-image"); + if (sourceImageInput.files.length > 0) { + const base64 = await fileToBase64(sourceImageInput.files[0]); + requestBody.file = base64; + } + + // Reference images (collect from all dynamic inputs) + const refImageInputs = document.querySelectorAll('.reference-image-file'); + const refImageFiles = []; + for (const input of refImageInputs) { + if (input.files.length > 0) { + refImageFiles.push(input.files[0]); + } + } + if (refImageFiles.length > 0) { + const base64Array = await filesToBase64Array(refImageFiles); + requestBody.ref_images = base64Array; + } + } catch (error) { + console.error("Error processing image files:", error); + resultDiv.innerHTML = 'Error processing image files: ' + error.message + '
'; + loader.style.display = "none"; + input.disabled = false; + generateBtn.disabled = false; return; } - const url = json.data[0].url; - var div = document.getElementById('result'); // Get the div by its ID - var img = document.createElement('img'); // Create a new img element - img.src = url; // Set the source of the image - img.alt = 'Generated image'; // Set the alt text of the image + // Make API request + try { + const response = await fetch("v1/images/generations", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(requestBody), + }); - div.innerHTML = ''; // Clear the existing content of the div - div.appendChild(img); // Add the new img element to the div + const json = await response.json(); - document.getElementById("loader").style.display = "none"; - document.getElementById("input").disabled = false; - document.getElementById("input").focus(); + if (json.error) { + // Display error + resultDiv.innerHTML = 'Error: ' + json.error.message + '
'; + loader.style.display = "none"; + input.disabled = false; + generateBtn.disabled = false; + return; + } + + // Clear result div + resultDiv.innerHTML = ''; + + // Display all generated images + if (json.data && json.data.length > 0) { + json.data.forEach((item, index) => { + const imageContainer = document.createElement("div"); + imageContainer.className = "mb-6 bg-[var(--color-bg-primary)]/50 border border-[#1E293B] rounded-xl p-4"; + + // Create image element + const img = document.createElement("img"); + if (item.url) { + img.src = item.url; + } else if (item.b64_json) { + img.src = "data:image/png;base64," + item.b64_json; + } else { + return; // Skip invalid items + } + img.alt = prompt; + img.className = "w-full h-auto rounded-lg mb-3"; + imageContainer.appendChild(img); + + // Create caption container + const captionDiv = document.createElement("div"); + captionDiv.className = "mt-3 p-3 bg-[var(--color-bg-secondary)] rounded-lg"; + + // Prompt caption + const promptCaption = document.createElement("p"); + promptCaption.className = "text-sm text-[var(--color-text-primary)] mb-2"; + promptCaption.innerHTML = 'Prompt: ' + escapeHtml(prompt); + captionDiv.appendChild(promptCaption); + + // Negative prompt if provided + if (negativePrompt) { + const negativeCaption = document.createElement("p"); + negativeCaption.className = "text-sm text-[var(--color-text-secondary)] mb-2"; + negativeCaption.innerHTML = 'Negative Prompt: ' + escapeHtml(negativePrompt); + captionDiv.appendChild(negativeCaption); + } + + // Generation details + const detailsDiv = document.createElement("div"); + detailsDiv.className = "flex flex-wrap gap-4 text-xs text-[var(--color-text-secondary)] mt-2"; + detailsDiv.innerHTML = ` + Size: ${size} + ${step !== undefined ? `Steps: ${step}` : ''} + ${seed !== undefined ? `Seed: ${seed}` : ''} + `; + captionDiv.appendChild(detailsDiv); + + // Copy prompt button + const copyBtn = document.createElement("button"); + copyBtn.className = "mt-2 px-3 py-1 text-xs bg-[var(--color-primary)] text-white rounded hover:opacity-80"; + copyBtn.innerHTML = 'Copy Prompt'; + copyBtn.onclick = () => { + navigator.clipboard.writeText(prompt).then(() => { + copyBtn.innerHTML = 'Copied!'; + setTimeout(() => { + copyBtn.innerHTML = 'Copy Prompt'; + }, 2000); + }); + }; + captionDiv.appendChild(copyBtn); + + imageContainer.appendChild(captionDiv); + resultDiv.appendChild(imageContainer); + }); + } else { + resultDiv.innerHTML = 'No images were generated.
'; + } + + // Preserve prompt in input field (don't clear it) + // The prompt is already in the input field, so we don't need to restore it + + } catch (error) { + console.error("Error generating image:", error); + resultDiv.innerHTML = 'Error: ' + error.message + '
'; + } finally { + // Hide loader and re-enable form + loader.style.display = "none"; + input.disabled = false; + generateBtn.disabled = false; + input.focus(); + } +} + +// Helper function to escape HTML +function escapeHtml(text) { + const div = document.createElement("div"); + div.textContent = text; + return div.innerHTML; } -document.getElementById("input").focus(); -document.getElementById("genimage").addEventListener("submit", genImage); +// Initialize +document.addEventListener("DOMContentLoaded", function() { + const input = document.getElementById("input"); + const form = document.getElementById("genimage"); + + if (input) { + input.focus(); + } + + if (form) { + form.addEventListener("submit", genImage); + } -// Handle Enter key press in the prompt input -document.getElementById("input").addEventListener("keypress", function(event) { - if (event.key === "Enter") { + // Handle Enter key press in the prompt input (but allow Shift+Enter for new lines) + if (input) { + input.addEventListener("keydown", function(event) { + if (event.key === "Enter" && !event.shiftKey) { event.preventDefault(); genImage(event); - } -}); + } + }); + } -document.getElementById("loader").style.display = "none"; + // Hide loader initially + const loader = document.getElementById("loader"); + if (loader) { + loader.style.display = "none"; + } +}); diff --git a/core/http/views/text2image.html b/core/http/views/text2image.html index 69f0a925400e..2955a6b400f6 100644 --- a/core/http/views/text2image.html +++ b/core/http/views/text2image.html @@ -55,46 +55,176 @@