Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketching + Inpainting Capabilities to Gradio #2144

Merged
merged 22 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 84 additions & 17 deletions demo/blocks_mask/run.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,93 @@
import gradio as gr
import os
from gradio.components import Markdown as md

def fn(mask):
return [mask["image"], mask["mask"]]
demo = gr.Blocks()

io1a = gr.Interface(lambda x: x, gr.Image(), gr.Image())
io1b = gr.Interface(lambda x: x, gr.Image(source='webcam'), gr.Image())

demo = gr.Blocks()
io2a = gr.Interface(lambda x: x, gr.Image(source='canvas'), gr.Image())
io2b = gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())

with demo:
with gr.Row():
with gr.Column():
img = gr.Image(
tool="sketch", source="upload", label="Mask", value=os.path.join(os.path.dirname(__file__), "lion.jpg")
)
with gr.Row():
btn = gr.Button("Run")
with gr.Column():
img2 = gr.Image()
img3 = gr.Image()

btn.click(fn=fn, inputs=img, outputs=[img2, img3])
io3a = gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='upload', tool='sketch'),
[gr.Image(), gr.Image()],
)

io3b = gr.Interface(
lambda x: [x['mask'], x['image']],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)

io3c = gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='webcam', tool='sketch'),
[gr.Image(), gr.Image()],
)

io4a = gr.Interface(lambda x: x, gr.Image(source='canvas', tool='color-sketch'), gr.Image())
io4b = gr.Interface(lambda x: x, gr.Paint(), gr.Image())

io5a = gr.Interface(lambda x: x, gr.Image(source='upload', tool='color-sketch'), gr.Image())
io5b = gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())
io5c = gr.Interface(lambda x: x, gr.Image(source='webcam', tool='color-sketch'), gr.Image())


with demo:
md("# Different Ways to Use the Image Input Component")
md("**1a. Standalone Image Upload: `gr.Interface(lambda x: x, gr.Image(), gr.Image())`**")
io1a.render()
md("**1b. Standalone Image from Webcam: `gr.Interface(lambda x: x, gr.Image(source='webcam'), gr.Image())`**")
io1b.render()
md("**2a. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas'), gr.Image())`**")
io2a.render()
md("**2b. Black and White Sketchpad: `gr.Interface(lambda x: x, gr.Sketchpad(), gr.Image())`**")
io2b.render()
md("**3a. Binary Mask with image upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='upload', tool='sketch'),
[gr.Image(), gr.Image()],
)
```
""")
io3a.render()
md("**3b. Binary Mask with image upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.ImageMask(),
[gr.Image(), gr.Image()],
)
```
""")
io3b.render()
md("**3c. Binary Mask with webcam upload:**")
md(
"""```python
gr.Interface(
lambda x: [x['mask'], x['image']],
gr.Image(source='webcam', tool='sketch'),
[gr.Image(), gr.Image()],
)
```
""")
io3c.render()
md("**4a. Color Sketchpad: `gr.Interface(lambda x: x, gr.Image(source='canvas', tool='color-sketch'), gr.Image())`**")
io4a.render()
md("**4b. Color Sketchpad: `gr.Interface(lambda x: x, gr.Paint(), gr.Image())`**")
io4b.render()
md("**5a. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.Image(source='upload', tool='color-sketch'), gr.Image())`**")
io5a.render()
md("**5b. Color Sketchpad with image upload: `gr.Interface(lambda x: x, gr.ImagePaint(), gr.Image())`**")
io5b.render()
md("**5c. Color Sketchpad with webcam upload: `gr.Interface(lambda x: x, gr.Image(source='webcam', tool='color-sketch'), gr.Image())`**")
io5c.render()

if __name__ == "__main__":
demo.launch()
2 changes: 1 addition & 1 deletion demo/filter_records/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def filter_records(records, gender):
headers=["name", "age", "gender"],
datatype=["str", "number", "str"],
row_count=5,
col_count=(3, "fixed")
col_count=(3, "fixed"),
),
gr.Dropdown(["M", "F", "O"]),
],
Expand Down
3 changes: 3 additions & 0 deletions gradio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@
from gradio.templates import (
Files,
Highlight,
ImageMask,
ImagePaint,
List,
Matrix,
Mic,
Microphone,
Numpy,
Paint,
Pil,
PlayableVideo,
Sketchpad,
Expand Down
39 changes: 21 additions & 18 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DataframeData(TypedDict):
import numpy as np
import pandas as pd
import PIL
import PIL.ImageOps
from ffmpy import FFmpeg
from markdown_it import MarkdownIt

Expand Down Expand Up @@ -1175,11 +1176,11 @@ def style(
)


@document("edit", "clear", "change", "stream", "change")
@document("edit", "clear", "change", "stream", "change", "style")
class Image(Editable, Clearable, Changeable, Streamable, IOComponent, ImgSerializable):
"""
Creates an image component that can be used to upload/draw images (as an input) or display images (as an output).
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch`. In the special case, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
Preprocessing: passes the uploaded image as a {numpy.array}, {PIL.Image} or {str} filepath depending on `type` -- unless `tool` is `sketch` AND source is one of `upload` or `webcam`. In these cases, a {dict} with keys `image` and `mask` is passed, and the format of the corresponding values depends on `type`.
Postprocessing: expects a {numpy.array}, {PIL.Image} or {str} or {pathlib.Path} filepath to an image and displays the image.
Examples-format: a {str} filepath to a local file that contains the image.
Demos: image_mod, image_mod_default_image
Expand All @@ -1194,7 +1195,7 @@ def __init__(
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "editor",
tool: str = None,
type: str = "numpy",
label: Optional[str] = None,
show_label: bool = True,
Expand All @@ -1212,7 +1213,7 @@ def __init__(
image_mode: "RGB" if color, or "L" if black and white.
invert_colors: whether to invert the image as a preprocessing step.
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
tool: Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool, "sketch" allows you to create a mask over the image and both the image and mask are passed into the function.
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to create a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" passes a str path to a temporary file containing the image.
label: component name in interface.
show_label: if True, will display label.
Expand All @@ -1228,7 +1229,10 @@ def __init__(
self.image_mode = image_mode
self.source = source
requires_permissions = source == "webcam"
self.tool = tool
if tool is None:
self.tool = "sketch" if source == "canvas" else "editor"
else:
self.tool = tool
self.invert_colors = invert_colors
self.test_input = deepcopy(media_data.BASE64_IMAGE)
self.interpret_by_tokens = True
Expand Down Expand Up @@ -1279,9 +1283,10 @@ def update(
return IOComponent.add_interactive_to_config(updated_config, interactive)

def _format_image(
self, im: Optional[PIL.Image], fmt: str
self, im: Optional[PIL.Image]
) -> np.array | PIL.Image | str | None:
"""Helper method to format an image based on self.type"""
fmt = im.format
if im is None:
return im
if self.type == "pil":
Expand Down Expand Up @@ -1314,17 +1319,16 @@ def generate_sample(self):
def preprocess(self, x: str | Dict) -> np.array | PIL.Image | str | None:
"""
Parameters:
x: base64 url data, or (if tool == "sketch) a dict of image and mask base64 url data
x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data
Returns:
image in requested format
image in requested format, or (if tool == "sketch") a dict of image and mask in requested format
"""
if x is None:
return x
if self.tool == "sketch":
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
x, mask = x["image"], x["mask"]

im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
Expand All @@ -1335,15 +1339,14 @@ def preprocess(self, x: str | Dict) -> np.array | PIL.Image | str | None:
if self.source == "webcam" and self.mirror_webcam is True:
im = PIL.ImageOps.mirror(im)

if not (self.tool == "sketch"):
return self._format_image(im, fmt)
if self.tool == "sketch" and self.source in ["upload", "webcam"]:
mask_im = processing_utils.decode_base64_to_image(mask)
return {
"image": self._format_image(im),
"mask": self._format_image(mask_im),
}

mask_im = processing_utils.decode_base64_to_image(mask)
mask_fmt = mask_im.format
return {
"image": self._format_image(im, fmt),
"mask": self._format_image(mask_im, mask_fmt),
}
return self._format_image(im)

def postprocess(self, y: np.ndarray | PIL.Image | str | Path) -> str:
"""
Expand Down
40 changes: 39 additions & 1 deletion gradio/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Webcam(components.Image):
is_template = True

def __init__(self, **kwargs):
super().__init__(source="webcam", **kwargs)
super().__init__(source="webcam", interactive=True, **kwargs)


class Sketchpad(components.Image):
Expand All @@ -47,10 +47,48 @@ def __init__(self, **kwargs):
source="canvas",
shape=(28, 28),
invert_colors=True,
interactive=True,
**kwargs
)


class Paint(components.Image):
    """
    Template for a drawing canvas that supports sketching in different colors.

    Sets: source="canvas", tool="color-sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        # Any remaining keyword arguments are forwarded to components.Image as-is.
        super().__init__(
            source="canvas",
            tool="color-sketch",
            interactive=True,
            **kwargs
        )


class ImageMask(components.Image):
    """
    Template for drawing a mask over an uploaded image.

    Sets: source="upload", tool="sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        # The source is "upload" (not "canvas"): the sketch tool is applied on
        # top of an uploaded image. Any remaining keyword arguments are
        # forwarded to components.Image as-is.
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)


class ImagePaint(components.Image):
    """
    Template for sketching in different colors on top of an uploaded image.

    Sets: source="upload", tool="color-sketch", interactive=True
    """

    is_template = True

    def __init__(self, **kwargs):
        # Any remaining keyword arguments are forwarded to components.Image as-is.
        super().__init__(
            source="upload",
            tool="color-sketch",
            interactive=True,
            **kwargs
        )


class Pil(components.Image):
"""
Sets: type="pil"
Expand Down
2 changes: 1 addition & 1 deletion gradio/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.1
3.4b0
9 changes: 9 additions & 0 deletions ui/packages/icons/src/Brush.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<svg width="100%" height="100%" viewBox="0 0 32 32"
><path
d="M28.828 3.172a4.094 4.094 0 0 0-5.656 0L4.05 22.292A6.954 6.954 0 0 0 2 27.242V30h2.756a6.952 6.952 0 0 0 4.95-2.05L28.828 8.829a3.999 3.999 0 0 0 0-5.657zM10.91 18.26l2.829 2.829l-2.122 2.121l-2.828-2.828zm-2.619 8.276A4.966 4.966 0 0 1 4.756 28H4v-.759a4.967 4.967 0 0 1 1.464-3.535l1.91-1.91l2.829 2.828zM27.415 7.414l-12.261 12.26l-2.829-2.828l12.262-12.26a2.047 2.047 0 0 1 2.828 0a2 2 0 0 1 0 2.828z"
fill="currentColor"
/><path
d="M6.5 15a3.5 3.5 0 0 1-2.475-5.974l3.5-3.5a1.502 1.502 0 0 0 0-2.121a1.537 1.537 0 0 0-2.121 0L3.415 5.394L2 3.98l1.99-1.988a3.585 3.585 0 0 1 4.95 0a3.504 3.504 0 0 1 0 4.949L5.439 10.44a1.502 1.502 0 0 0 0 2.121a1.537 1.537 0 0 0 2.122 0l4.024-4.024L13 9.95l-4.025 4.024A3.475 3.475 0 0 1 6.5 15z"
fill="currentColor"
/></svg
>
25 changes: 8 additions & 17 deletions ui/packages/icons/src/Chart.svelte
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
<svg
xmlns="http://www.w3.org/2000/svg"
width="100%"
height="100%"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
stroke-linecap="round"
stroke-linejoin="round"
class="feather feather-bar-chart-2"
><line x1="18" y1="20" x2="18" y2="10" /><line
x1="12"
y1="20"
x2="12"
y2="4"
/><line x1="6" y1="20" x2="6" y2="14" /></svg
<svg width="1em" height="1em" viewBox="0 0 32 32"
><path
d="M28.828 3.172a4.094 4.094 0 0 0-5.656 0L4.05 22.292A6.954 6.954 0 0 0 2 27.242V30h2.756a6.952 6.952 0 0 0 4.95-2.05L28.828 8.829a3.999 3.999 0 0 0 0-5.657zM10.91 18.26l2.829 2.829l-2.122 2.121l-2.828-2.828zm-2.619 8.276A4.966 4.966 0 0 1 4.756 28H4v-.759a4.967 4.967 0 0 1 1.464-3.535l1.91-1.91l2.829 2.828zM27.415 7.414l-12.261 12.26l-2.829-2.828l12.262-12.26a2.047 2.047 0 0 1 2.828 0a2 2 0 0 1 0 2.828z"
fill="currentColor"
/><path
d="M6.5 15a3.5 3.5 0 0 1-2.475-5.974l3.5-3.5a1.502 1.502 0 0 0 0-2.121a1.537 1.537 0 0 0-2.121 0L3.415 5.394L2 3.98l1.99-1.988a3.585 3.585 0 0 1 4.95 0a3.504 3.504 0 0 1 0 4.949L5.439 10.44a1.502 1.502 0 0 0 0 2.121a1.537 1.537 0 0 0 2.122 0l4.024-4.024L13 9.95l-4.025 4.024A3.475 3.475 0 0 1 6.5 15z"
fill="currentColor"
/></svg
>
16 changes: 16 additions & 0 deletions ui/packages/icons/src/Color.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<svg width="100%" height="100%" viewBox="0 0 32 32"
><circle cx="10" cy="12" r="2" fill="currentColor" /><circle
cx="16"
cy="9"
r="2"
fill="currentColor"
/><circle cx="22" cy="12" r="2" fill="currentColor" /><circle
cx="23"
cy="18"
r="2"
fill="currentColor"
/><circle cx="19" cy="23" r="2" fill="currentColor" /><path
fill="currentColor"
d="M16.54 2A14 14 0 0 0 2 16a4.82 4.82 0 0 0 6.09 4.65l1.12-.31a3 3 0 0 1 3.79 2.9V27a3 3 0 0 0 3 3a14 14 0 0 0 14-14.54A14.05 14.05 0 0 0 16.54 2Zm8.11 22.31A11.93 11.93 0 0 1 16 28a1 1 0 0 1-1-1v-3.76a5 5 0 0 0-5-5a5.07 5.07 0 0 0-1.33.18l-1.12.31A2.82 2.82 0 0 1 4 16A12 12 0 0 1 16.47 4A12.18 12.18 0 0 1 28 15.53a11.89 11.89 0 0 1-3.35 8.79Z"
/></svg
>
30 changes: 16 additions & 14 deletions ui/packages/icons/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
export { default as Clear } from "./Clear.svelte";
export { default as Brush } from "./Brush.svelte";
export { default as Camera } from "./Camera.svelte";
export { default as Chart } from "./Chart.svelte";
export { default as Chat } from "./Chat.svelte";
export { default as Circle } from "./Circle.svelte";
export { default as Clear } from "./Clear.svelte";
export { default as Color } from "./Color.svelte";
export { default as Edit } from "./Edit.svelte";
export { default as File } from "./File.svelte";
export { default as Image } from "./Image.svelte";
export { default as JSON } from "./JSON.svelte";
export { default as LineChart } from "./LineChart.svelte";
export { default as Maximise } from "./Maximise.svelte";
export { default as Music } from "./Music.svelte";
export { default as Pause } from "./Pause.svelte";
export { default as Play } from "./Play.svelte";
export { default as Plot } from "./Plot.svelte";
export { default as Sketch } from "./Sketch.svelte";
export { default as Square } from "./Square.svelte";
export { default as Table } from "./Table.svelte";
export { default as Undo } from "./Undo.svelte";
export { default as Video } from "./Video.svelte";
export { default as Image } from "./Image.svelte";
export { default as Chart } from "./Chart.svelte";
export { default as Music } from "./Music.svelte";
export { default as File } from "./File.svelte";
export { default as LineChart } from "./LineChart.svelte";
export { default as TextHighlight } from "./TextHighlight.svelte";
export { default as JSON } from "./JSON.svelte";
export { default as Tree } from "./Tree.svelte";
export { default as Chat } from "./Chat.svelte";
export { default as Plot } from "./Plot.svelte";
export { default as Play } from "./Play.svelte";
export { default as Pause } from "./Pause.svelte";
export { default as Maximise } from "./Maximise.svelte";
export { default as Undo } from "./Undo.svelte";
export { default as Video } from "./Video.svelte";
Loading