-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add bynary stream support based on #421
- Loading branch information
1 parent
645afeb
commit c287d1c
Showing
1 changed file
with
161 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import asyncio | ||
import io | ||
import json | ||
import os | ||
import sys | ||
from typing import IO | ||
|
||
import click | ||
from PIL import Image | ||
|
||
from ..bg import remove | ||
from ..session_factory import new_session | ||
from ..sessions import sessions_names | ||
|
||
|
||
@click.command( | ||
name="b", | ||
help="for a byte stream as input", | ||
) | ||
@click.option( | ||
"-m", | ||
"--model", | ||
default="u2net", | ||
type=click.Choice(sessions_names), | ||
show_default=True, | ||
show_choices=True, | ||
help="model name", | ||
) | ||
@click.option( | ||
"-a", | ||
"--alpha-matting", | ||
is_flag=True, | ||
show_default=True, | ||
help="use alpha matting", | ||
) | ||
@click.option( | ||
"-af", | ||
"--alpha-matting-foreground-threshold", | ||
default=240, | ||
type=int, | ||
show_default=True, | ||
help="trimap fg threshold", | ||
) | ||
@click.option( | ||
"-ab", | ||
"--alpha-matting-background-threshold", | ||
default=10, | ||
type=int, | ||
show_default=True, | ||
help="trimap bg threshold", | ||
) | ||
@click.option( | ||
"-ae", | ||
"--alpha-matting-erode-size", | ||
default=10, | ||
type=int, | ||
show_default=True, | ||
help="erode size", | ||
) | ||
@click.option( | ||
"-om", | ||
"--only-mask", | ||
is_flag=True, | ||
show_default=True, | ||
help="output only the mask", | ||
) | ||
@click.option( | ||
"-ppm", | ||
"--post-process-mask", | ||
is_flag=True, | ||
show_default=True, | ||
help="post process the mask", | ||
) | ||
@click.option( | ||
"-bgc", | ||
"--bgcolor", | ||
default=None, | ||
type=(int, int, int, int), | ||
nargs=4, | ||
help="Background color (R G B A) to replace the removed background with", | ||
) | ||
@click.option("-x", "--extras", type=str) | ||
@click.option( | ||
"-o", | ||
"--output_specifier", | ||
type=str, | ||
help="printf-style specifier for output filenames (e.g. 'output-%d.png'))", | ||
) | ||
@click.argument( | ||
"image_width", | ||
type=int, | ||
) | ||
@click.argument( | ||
"image_height", | ||
type=int, | ||
) | ||
def rs_command( | ||
model: str, | ||
extras: str, | ||
image_width: int, | ||
image_height: int, | ||
output_specifier: str, | ||
**kwargs | ||
) -> None: | ||
try: | ||
kwargs.update(json.loads(extras)) | ||
except Exception: | ||
pass | ||
|
||
session = new_session(model) | ||
bytes_per_img = image_width * image_height * 3 | ||
|
||
if output_specifier: | ||
output_dir = os.path.dirname( | ||
os.path.abspath(os.path.expanduser(output_specifier)) | ||
) | ||
|
||
if not os.path.isdir(output_dir): | ||
os.makedirs(output_dir, exist_ok=True) | ||
|
||
def img_to_byte_array(img: Image) -> bytes: | ||
buff = io.BytesIO() | ||
img.save(buff, format="PNG") | ||
return buff.getvalue() | ||
|
||
async def connect_stdin_stdout(): | ||
loop = asyncio.get_event_loop() | ||
reader = asyncio.StreamReader() | ||
protocol = asyncio.StreamReaderProtocol(reader) | ||
|
||
await loop.connect_read_pipe(lambda: protocol, sys.stdin) | ||
w_transport, w_protocol = await loop.connect_write_pipe( | ||
asyncio.streams.FlowControlMixin, sys.stdout | ||
) | ||
|
||
writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop) | ||
return reader, writer | ||
|
||
async def main(): | ||
reader, writer = await connect_stdin_stdout() | ||
|
||
idx = 0 | ||
while True: | ||
try: | ||
img_bytes = await reader.readexactly(bytes_per_img) | ||
if not img_bytes: | ||
break | ||
|
||
img = Image.frombytes("RGB", (image_width, image_height), img_bytes) | ||
output = remove(img, session=session, **kwargs) | ||
|
||
if output_specifier: | ||
output.save((output_specifier % idx), format="PNG") | ||
else: | ||
writer.write(img_to_byte_array(output)) | ||
|
||
idx += 1 | ||
except asyncio.IncompleteReadError: | ||
break | ||
|
||
asyncio.run(main()) |