Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketching + Inpainting Capabilities to Gradio #2144

Merged
merged 22 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added demo/all_demos/tmp.zip
Binary file not shown.
142 changes: 126 additions & 16 deletions demo/blocks_mask/run.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,135 @@
import gradio as gr
import os
from gradio.components import Markdown as md

def fn(mask):
return [mask["image"], mask["mask"]]
demo = gr.Blocks()

io1a = gr.Interface(lambda x: x, gr.Image(), gr.Image())
io1b = gr.Interface(lambda x: x, gr.Image(source="webcam"), gr.Image())

io2a = gr.Interface(lambda x: x, gr.Image(source="canvas"), gr.Image())
io2b = gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())

io3a = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.Image(source="upload", tool="sketch"),
[gr.Image(), gr.Image()],
)

io3b = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3b2 = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3b3 = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3c = gr.Interface(
lambda x: [x["mask"], x["image"]],
gr.Image(source="webcam", tool="sketch"),
[gr.Image(), gr.Image()],
)

io4a = gr.Interface(
lambda x: x, gr.Image(source="canvas", tool="color-sketch"), gr.Image()
)
io4b = gr.Interface(lambda x: x, gr.Paint(), gr.Image())

io5a = gr.Interface(
lambda x: x, gr.Image(source="upload", tool="color-sketch"), gr.Image()
)
io5b = gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())
io5c = gr.Interface(
lambda x: x, gr.Image(source="webcam", tool="color-sketch"), gr.Image()
)

demo = gr.Blocks()

with demo:
with gr.Row():
with gr.Column():
img = gr.Image(
tool="sketch", source="upload", label="Mask", value=os.path.join(os.path.dirname(__file__), "lion.jpg")
)
with gr.Row():
btn = gr.Button("Run")
with gr.Column():
img2 = gr.Image()
img3 = gr.Image()

btn.click(fn=fn, inputs=img, outputs=[img2, img3])
md("# Different Ways to Use the Image Input Component")
md(
"**1a. Standalone Image Upload: `gr.Interface(lambda x: x, gr.Image(), gr.Image())`**"
)
io1a.render()
md(
"**1b. Standalone Image from Webcam: `gr.Interface(lambda x: x, gr.Image(source='webcam'), gr.Image())`**"
)
io1b.render()
md(
"**2a. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas'), gr.Image())`**"
)
io2a.render()
md(
"**2b. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())`**"
)
io2b.render()
md("**3a. Binary Mask with image upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='upload', tool='sketch'),
[gr.Image(), gr.Image()],
)
```
"""
)
io3a.render()
md("**3b. Binary Mask with image upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)
```
"""
)
io3b.render()
md("**3c. Binary Mask with webcam upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='webcam', tool='sketch'),
[gr.Image(), gr.Image()],
)
```
"""
)
io3c.render()
md(
"**4a. Color Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas', tool='color-sketch'), gr.Image())`**"
)
io4a.render()
md("**4b. Color Sketchpad: `gr.Interface(lambda x: x, gr.Paint(), gr.Image())`**")
io4b.render()
md(
"**5a. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.Image(source='upload', tool='color-sketch'), gr.Image())`**"
)
io5a.render()
md(
"**5b. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())`**"
)
io5b.render()
md(
"**5c. Color Sketchpad with webcam upload: `gr.Interface(lambda x: x, gr.Image(source='webcam', tool='color-sketch'), gr.Image())`**"
)
io5c.render()
md("**Tabs**")
with gr.Tab("One"):
io3b2.render()
with gr.Tab("Two"):
io3b3.render()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion demo/filter_records/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def filter_records(records, gender):
headers=["name", "age", "gender"],
datatype=["str", "number", "str"],
row_count=5,
col_count=(3, "fixed")
col_count=(3, "fixed"),
),
gr.Dropdown(["M", "F", "O"]),
],
Expand Down
3 changes: 3 additions & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@
from gradio.templates import (
Files,
Highlight,
ImageMask,
ImagePaint,
List,
Matrix,
Mic,
Microphone,
Numpy,
Paint,
Pil,
PlayableVideo,
Sketchpad,
Expand Down
46 changes: 26 additions & 20 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DataframeData(TypedDict):
import numpy as np
import pandas as pd
import PIL
import PIL.ImageOps
from ffmpy import FFmpeg
from markdown_it import MarkdownIt

Expand Down Expand Up @@ -1175,11 +1176,11 @@ def style(
)


@document("edit", "clear", "change", "stream", "change")
@document("edit", "clear", "change", "stream", "change", "style")
class Image(Editable, Clearable, Changeable, Streamable, IOComponent, ImgSerializable):
"""
Creates an image component that can be used to upload/draw images (as an input) or display images (as an output).
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch`. In the special case, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image.
Examples-format: a {str} filepath to a local file that contains the image.
Demos: image_mod, image_mod_default_image
Expand All @@ -1194,7 +1195,7 @@ def __init__(
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "editor",
tool: str = None,
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
Expand All @@ -1212,7 +1213,7 @@ def __init__(
image_mode: "RGB" if color, or "L" if black and white.
invert_colors: whether to invert the image as a preprocessing step.
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
tool: Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool, "sketch" allows you to create a mask over the image and both the image and mask are passed into the function.
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to create a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" passes a str path to a temporary file containing the image.
label: component name in interface.
show_label: if True, will display label.
Expand All @@ -1228,7 +1229,10 @@ def __init__(
self.image_mode = image_mode
self.source = source
requires_permissions = source == "webcam"
self.tool = tool
if tool is None:
self.tool = "sketch" if source == "canvas" else "editor"
else:
self.tool = tool
self.invert_colors = invert_colors
self.test_input = deepcopy(media_data.BASE64_IMAGE)
self.interpret_by_tokens = True
Expand Down Expand Up @@ -1279,9 +1283,10 @@ def update(
return IOComponent.add_interactive_to_config(updated_config, interactive)

def _format_image(
self, im: Optional[PIL.Image], fmt: str
self, im: Optional[PIL.Image]
) -> np.array | PIL.Image | str | None:
"""Helper method to format an image based on self.type"""
fmt = im.format
if im is None:
return im
if self.type == "pil":
Expand Down Expand Up @@ -1314,36 +1319,37 @@ def generate_sample(self):
def preprocess(self, x: str | Dict) -> np.array | PIL.Image | str | None:
"""
Parameters:
x: base64 url data, or (if tool == "sketch) a dict of image and mask base64 url data
x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data
Returns:
image in requested format
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
"""
if x is None:
return x
if self.tool == "sketch":
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
x, mask = x["image"], x["mask"]

im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
if self.shape is not None:
im = processing_utils.resize_and_crop(im, self.shape)
if self.invert_colors:
im = PIL.ImageOps.invert(im)
if self.source == "webcam" and self.mirror_webcam is True:
if (
self.source == "webcam"
and self.mirror_webcam is True
and self.tool != "color-sketch"
):
im = PIL.ImageOps.mirror(im)

if not (self.tool == "sketch"):
return self._format_image(im, fmt)
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
mask_im = processing_utils.decode_base64_to_image(mask)
return {
"image": self._format_image(im),
"mask": self._format_image(mask_im),
}

mask_im = processing_utils.decode_base64_to_image(mask)
mask_fmt = mask_im.format
return {
"image": self._format_image(im, fmt),
"mask": self._format_image(mask_im, mask_fmt),
}
return self._format_image(im)

def postprocess(self, y: np.ndarray | PIL.Image | str | Path) -> str:
"""
Expand Down
40 changes: 39 additions & 1 deletion gradio/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Webcam(components.Image):
is_template = True

def __init__(self, **kwargs):
super().__init__(source="webcam", **kwargs)
super().__init__(source="webcam", interactive=True, **kwargs)


class Sketchpad(components.Image):
Expand All @@ -47,10 +47,48 @@ def __init__(self, **kwargs):
source="canvas",
shape=(28, 28),
invert_colors=True,
interactive=True,
**kwargs
)


class Paint(components.Image):
    """
    Sets: source="canvas", tool="color-sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        # The three preset values are fixed by this template; every other
        # keyword argument is forwarded unchanged to components.Image.
        super().__init__(
            interactive=True,
            tool="color-sketch",
            source="canvas",
            **kwargs,
        )


class ImageMask(components.Image):
    """
    Sets: source="upload", tool="sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        # The preset is an *uploaded* image with the "sketch" tool (not a
        # blank canvas); remaining keyword arguments are forwarded unchanged
        # to components.Image.
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)


class ImagePaint(components.Image):
    """
    Sets: source="upload", tool="color-sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        # Remaining keyword arguments are forwarded unchanged to
        # components.Image.
        super().__init__(
            source="upload", tool="color-sketch", interactive=True, **kwargs
        )


class Pil(components.Image):
"""
Sets: type="pil"
Expand Down
2 changes: 1 addition & 1 deletion gradio/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.1
3.4b0
4 changes: 2 additions & 2 deletions ui/packages/app/test/blocks_xray.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ test("can run an api request and display the data", async ({ page }) => {
await page.check("label:has-text('Covid')");
await page.check("label:has-text('Lung Cancer')");

const run_button = await page.locator("button", { hasText: /Run/ });
const run_button = await page.locator("button", { hasText: /Run/ }).first();

await Promise.all([
run_button.click(),
page.waitForResponse("**/api/predict/")
]);

const json = await page.locator("data-testid=json");
const json = await page.locator("data-testid=json").first();
await expect(json).toContainText(`Covid: 0.75, Lung Cancer: 0.25`);
});
9 changes: 9 additions & 0 deletions ui/packages/icons/src/Brush.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<svg width="100%" height="100%" viewBox="0 0 32 32"
><path
d="M28.828 3.172a4.094 4.094 0 0 0-5.656 0L4.05 22.292A6.954 6.954 0 0 0 2 27.242V30h2.756a6.952 6.952 0 0 0 4.95-2.05L28.828 8.829a3.999 3.999 0 0 0 0-5.657zM10.91 18.26l2.829 2.829l-2.122 2.121l-2.828-2.828zm-2.619 8.276A4.966 4.966 0 0 1 4.756 28H4v-.759a4.967 4.967 0 0 1 1.464-3.535l1.91-1.91l2.829 2.828zM27.415 7.414l-12.261 12.26l-2.829-2.828l12.262-12.26a2.047 2.047 0 0 1 2.828 0a2 2 0 0 1 0 2.828z"
fill="currentColor"
/><path
d="M6.5 15a3.5 3.5 0 0 1-2.475-5.974l3.5-3.5a1.502 1.502 0 0 0 0-2.121a1.537 1.537 0 0 0-2.121 0L3.415 5.394L2 3.98l1.99-1.988a3.585 3.585 0 0 1 4.95 0a3.504 3.504 0 0 1 0 4.949L5.439 10.44a1.502 1.502 0 0 0 0 2.121a1.537 1.537 0 0 0 2.122 0l4.024-4.024L13 9.95l-4.025 4.024A3.475 3.475 0 0 1 6.5 15z"
fill="currentColor"
/></svg
>
Loading