Skip to content

Commit

Permalink
Sketching + Inpainting Capabilities to Gradio (#2144)
Browse files Browse the repository at this point in the history
* templates

* working on backend

* formatting

* Sketching fe (#2184)

* fix scaling on sketch + bg img

* tweaks

* ketch updates

* cursor style

* sketchpad

* fixes

* ensure background is white for bw sketch

* fix everything

* re-enable demos

* updated demo and changed from dict to str

* beta release

* fix bugs, tweak webcam source

* re-anable demos

* fix clear button and tab changing

* maybe fix test

* maybe fix test again maybe

* various fixes

* fix img uplaod + color sketch

* remove lazy brush but keep smoothing

* fix sketch bg

Co-authored-by: pngwn <[email protected]>
  • Loading branch information
abidlabs and pngwn authored Sep 23, 2022
1 parent 581fbab commit cecaf1a
Show file tree
Hide file tree
Showing 20 changed files with 727 additions and 356 deletions.
Binary file added demo/all_demos/tmp.zip
Binary file not shown.
142 changes: 126 additions & 16 deletions demo/blocks_mask/run.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,135 @@
import gradio as gr
import os
from gradio.components import Markdown as md

def fn(mask):
return [mask["image"], mask["mask"]]
demo = gr.Blocks()

io1a = gr.Interface(lambda x: x, gr.Image(), gr.Image())
io1b = gr.Interface(lambda x: x, gr.Image(source="webcam"), gr.Image())

io2a = gr.Interface(lambda x: x, gr.Image(source="canvas"), gr.Image())
io2b = gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())

io3a = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.Image(source="upload", tool="sketch"),
[gr.Image(), gr.Image()],
)

io3b = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3b2 = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3b3 = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3c = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.Image(source="webcam", tool="sketch"),
[gr.Image(), gr.Image()],
)

io4a = gr.Interface(
lambda x: x, gr.Image(source="canvas", tool="color-sketch"), gr.Image()
)
io4b = gr.Interface(lambda x: x, gr.Paint(), gr.Image())

io5a = gr.Interface(
lambda x: x, gr.Image(source="upload", tool="color-sketch"), gr.Image()
)
io5b = gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())
io5c = gr.Interface(
lambda x: x, gr.Image(source="webcam", tool="color-sketch"), gr.Image()
)

demo = gr.Blocks()

with demo:
with gr.Row():
with gr.Column():
img = gr.Image(
tool="sketch", source="upload", label="Mask", value=os.path.join(os.path.dirname(__file__), "lion.jpg")
)
with gr.Row():
btn = gr.Button("Run")
with gr.Column():
img2 = gr.Image()
img3 = gr.Image()

btn.click(fn=fn, inputs=img, outputs=[img2, img3])
md("# Different Ways to Use the Image Input Component")
md(
"**1a. Standalone Image Upload: `gr.Interface(lambda x: x, gr.Image(), gr.Image())`**"
)
io1a.render()
md(
"**1b. Standalone Image from Webcam: `gr.Interface(lambda x: x, gr.Image(source='webcam'), gr.Image())`**"
)
io1b.render()
md(
"**2a. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas'), gr.Image())`**"
)
io2a.render()
md(
"**2b. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())`**"
)
io2b.render()
md("**3a. Binary Mask with image upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='upload', tool='sketch'),
[gr.Image(), gr.Image()],
)
```
"""
)
io3a.render()
md("**3b. Binary Mask with image upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)
```
"""
)
io3b.render()
md("**3c. Binary Mask with webcam upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='webcam', tool='sketch'),
[gr.Image(), gr.Image()],
)
```
"""
)
io3c.render()
md(
"**4a. Color Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas', tool='color-sketch'), gr.Image())`**"
)
io4a.render()
md("**4b. Color Sketchpad: `gr.Interface(lambda x: x, gr.Paint(), gr.Image())`**")
io4b.render()
md(
"**5a. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.Image(source='upload', tool='color-sketch'), gr.Image())`**"
)
io5a.render()
md(
"**5b. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())`**"
)
io5b.render()
md(
"**5c. Color Sketchpad with webcam upload: `gr.Interface(lambda x: x, gr.Image(source='webcam', tool='color-sketch'), gr.Image())`**"
)
io5c.render()
md("**Tabs**")
with gr.Tab("One"):
io3b2.render()
with gr.Tab("Two"):
io3b3.render()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion demo/filter_records/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def filter_records(records, gender):
headers=["name", "age", "gender"],
datatype=["str", "number", "str"],
row_count=5,
col_count=(3, "fixed")
col_count=(3, "fixed"),
),
gr.Dropdown(["M", "F", "O"]),
],
Expand Down
3 changes: 3 additions & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@
from gradio.templates import (
Files,
Highlight,
ImageMask,
ImagePaint,
List,
Matrix,
Mic,
Microphone,
Numpy,
Paint,
Pil,
PlayableVideo,
Sketchpad,
Expand Down
46 changes: 26 additions & 20 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DataframeData(TypedDict):
import numpy as np
import pandas as pd
import PIL
import PIL.ImageOps
from ffmpy import FFmpeg
from markdown_it import MarkdownIt

Expand Down Expand Up @@ -1175,11 +1176,11 @@ def style(
)


@document("edit", "clear", "change", "stream", "change")
@document("edit", "clear", "change", "stream", "change", "style")
class Image(Editable, Clearable, Changeable, Streamable, IOComponent, ImgSerializable):
"""
Creates an image component that can be used to upload/draw images (as an input) or display images (as an output).
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch`. In the special case, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image.
Examples-format: a {str} filepath to a local file that contains the image.
Demos: image_mod, image_mod_default_image
Expand All @@ -1194,7 +1195,7 @@ def __init__(
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "editor",
tool: str = None,
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
Expand All @@ -1212,7 +1213,7 @@ def __init__(
image_mode: "RGB" if color, or "L" if black and white.
invert_colors: whether to invert the image as a preprocessing step.
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
tool: Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool, "sketch" allows you to create a mask over the image and both the image and mask are passed into the function.
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to created a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" passes a str path to a temporary file containing the image.
label: component name in interface.
show_label: if True, will display label.
Expand All @@ -1228,7 +1229,10 @@ def __init__(
self.image_mode = image_mode
self.source = source
requires_permissions = source == "webcam"
self.tool = tool
if tool is None:
self.tool = "sketch" if source == "canvas" else "editor"
else:
self.tool = tool
self.invert_colors = invert_colors
self.test_input = deepcopy(media_data.BASE64_IMAGE)
self.interpret_by_tokens = True
Expand Down Expand Up @@ -1279,9 +1283,10 @@ def update(
return IOComponent.add_interactive_to_config(updated_config, interactive)

def _format_image(
self, im: Optional[PIL.Image], fmt: str
self, im: Optional[PIL.Image]
) -> np.array | PIL.Image | str | None:
"""Helper method to format an image based on self.type"""
fmt = im.format
if im is None:
return im
if self.type == "pil":
Expand Down Expand Up @@ -1314,36 +1319,37 @@ def generate_sample(self):
def preprocess(self, x: str | Dict) -> np.array | PIL.Image | str | None:
"""
Parameters:
x: base64 url data, or (if tool == "sketch) a dict of image and mask base64 url data
x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data
Returns:
image in requested format
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
"""
if x is None:
return x
if self.tool == "sketch":
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
x, mask = x["image"], x["mask"]

im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
if self.shape is not None:
im = processing_utils.resize_and_crop(im, self.shape)
if self.invert_colors:
im = PIL.ImageOps.invert(im)
if self.source == "webcam" and self.mirror_webcam is True:
if (
self.source == "webcam"
and self.mirror_webcam is True
and self.tool != "color-sketch"
):
im = PIL.ImageOps.mirror(im)

if not (self.tool == "sketch"):
return self._format_image(im, fmt)
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
mask_im = processing_utils.decode_base64_to_image(mask)
return {
"image": self._format_image(im),
"mask": self._format_image(mask_im),
}

mask_im = processing_utils.decode_base64_to_image(mask)
mask_fmt = mask_im.format
return {
"image": self._format_image(im, fmt),
"mask": self._format_image(mask_im, mask_fmt),
}
return self._format_image(im)

def postprocess(self, y: np.ndarray | PIL.Image | str | Path) -> str:
"""
Expand Down
40 changes: 39 additions & 1 deletion gradio/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Webcam(components.Image):
is_template = True

def __init__(self, **kwargs):
super().__init__(source="webcam", **kwargs)
super().__init__(source="webcam", interactive=True, **kwargs)


class Sketchpad(components.Image):
Expand All @@ -47,10 +47,48 @@ def __init__(self, **kwargs):
source="canvas",
shape=(28, 28),
invert_colors=True,
interactive=True,
**kwargs
)


class Paint(components.Image):
"""
Sets source="canvas", tool="color-sketch"
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(
source="canvas", tool="color-sketch", interactive=True, **kwargs
)


class ImageMask(components.Image):
"""
Sets source="canvas", tool="sketch"
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)


class ImagePaint(components.Image):
"""
Sets source="upload", tool="color-sketch"
"""

is_template = True

def __init__(self, **kwargs):
super().__init__(
source="upload", tool="color-sketch", interactive=True, **kwargs
)


class Pil(components.Image):
"""
Sets: type="pil"
Expand Down
2 changes: 1 addition & 1 deletion gradio/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.1
3.4b0
4 changes: 2 additions & 2 deletions ui/packages/app/test/blocks_xray.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ test("can run an api request and display the data", async ({ page }) => {
await page.check("label:has-text('Covid')");
await page.check("label:has-text('Lung Cancer')");

const run_button = await page.locator("button", { hasText: /Run/ });
const run_button = await page.locator("button", { hasText: /Run/ }).first();

await Promise.all([
run_button.click(),
page.waitForResponse("**/api/predict/")
]);

const json = await page.locator("data-testid=json");
const json = await page.locator("data-testid=json").first();
await expect(json).toContainText(`Covid: 0.75, Lung Cancer: 0.25`);
});
9 changes: 9 additions & 0 deletions ui/packages/icons/src/Brush.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<svg width="100%" height="100%" viewBox="0 0 32 32"
><path
d="M28.828 3.172a4.094 4.094 0 0 0-5.656 0L4.05 22.292A6.954 6.954 0 0 0 2 27.242V30h2.756a6.952 6.952 0 0 0 4.95-2.05L28.828 8.829a3.999 3.999 0 0 0 0-5.657zM10.91 18.26l2.829 2.829l-2.122 2.121l-2.828-2.828zm-2.619 8.276A4.966 4.966 0 0 1 4.756 28H4v-.759a4.967 4.967 0 0 1 1.464-3.535l1.91-1.91l2.829 2.828zM27.415 7.414l-12.261 12.26l-2.829-2.828l12.262-12.26a2.047 2.047 0 0 1 2.828 0a2 2 0 0 1 0 2.828z"
fill="currentColor"
/><path
d="M6.5 15a3.5 3.5 0 0 1-2.475-5.974l3.5-3.5a1.502 1.502 0 0 0 0-2.121a1.537 1.537 0 0 0-2.121 0L3.415 5.394L2 3.98l1.99-1.988a3.585 3.585 0 0 1 4.95 0a3.504 3.504 0 0 1 0 4.949L5.439 10.44a1.502 1.502 0 0 0 0 2.121a1.537 1.537 0 0 0 2.122 0l4.024-4.024L13 9.95l-4.025 4.024A3.475 3.475 0 0 1 6.5 15z"
fill="currentColor"
/></svg
>
Loading

0 comments on commit cecaf1a

Please sign in to comment.