diff --git a/pipelines/examples/text_to_image/README.md b/pipelines/examples/text_to_image/README.md new file mode 100644 index 000000000000..23512c40ef23 --- /dev/null +++ b/pipelines/examples/text_to_image/README.md @@ -0,0 +1,114 @@ +# ERNIE-ViLG 文生图系统 + +## 1. 场景概述 + +ERNIE-ViLG是一个知识增强跨模态图文生成大模型,将文生成图和图生成文任务融合到同一个模型进行端到端的学习,从而实现文本和图像的跨模态语义对齐。可以支持用户进行内容创作,让每个用户都能够体验到一个低门槛的创作平台。更多详细信息请参考官网的介绍[ernieVilg](https://wenxin.baidu.com/moduleApi/ernieVilg) + + +## 2. 产品功能介绍 + +本项目提供了低成本搭建端到端文生图的能力。用户需要进行简单的参数配置,然后输入prompts就可以生成各种风格的画作,另外,Pipelines提供了 Web 化产品服务,让用户在本地端就能搭建起来文生图系统。 + + +## 3. 快速开始: 快速搭建文生图系统 + + +### 3.1 运行环境和安装说明 + +本实验采用了以下的运行环境进行,详细说明如下,用户也可以在自己的环境进行: + +a. 软件环境: +- python >= 3.7.0 +- paddlenlp >= 2.4.0 +- paddlepaddle-gpu >=2.3 +- CUDA Version: 10.2 +- NVIDIA Driver Version: 440.64.00 +- Ubuntu 16.04.6 LTS (Docker) + +b. 硬件环境: + +- NVIDIA Tesla V100 16GB x4卡 +- Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz + +c. 依赖安装: +首先需要安装PaddlePaddle,PaddlePaddle的安装请参考文档[官方安装文档](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html),然后安装下面的依赖: +```bash +# pip 一键安装 +pip install --upgrade paddle-pipelines -i https://pypi.tuna.tsinghua.edu.cn/simple +# 或者源码进行安装最新版本 +cd ${HOME}/PaddleNLP/pipelines/ +pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +python setup.py install +``` +【注意】以下的所有的流程都只需要在`pipelines`根目录下进行,不需要跳转目录;另外,文生图系统需要联网,用户需要在有网的环境下进行。 + + +### 3.2 一键体验文生图系统 + +在运行下面的命令之前,需要在[ERNIE-ViLG官网](https://wenxin.baidu.com/moduleApi/ernieVilg)申请`API Key`和 `Secret key`两个密钥(需要登录,登录后点击右上角的查看AK/SK,具体如下图),然后执行下面的命令。 + +
+ +
+ + +#### 3.2.1 快速一键启动 + +您可以通过如下命令快速体验文生图系统的效果 +```bash +python examples/text_to_image/text_to_image_example.py --prompt_text 宁静的小镇 \ + --style 古风 \ + --topk 5 \ + --api_key 你申请的apikey \ + --secret_key 你申请的secretkey \ + --output_dir ernievilg_output +``` +大概运行一分钟后就可以得到结果了,生成的图片请查看您的输出目录`output_dir`。 + +### 3.3 构建 Web 可视化文生图系统 + +整个 Web 可视化文生图系统主要包含 2 大组件: 1. 基于 RestfulAPI 构建模型服务 2. 基于 Gradio 构建 WebUI,接下来我们依次搭建这 2 个服务并最终形成可视化的文生图系统。 + +#### 3.3.1 启动 RestAPI 模型服务 + +启动之前,需要把您申请的`API Key`和 `Secret key`两个密钥添加到`text_to_image.yaml`的ak和sk的位置,然后运行: + +```bash +export PIPELINE_YAML_PATH=rest_api/pipeline/text_to_image.yaml +# 使用端口号 8891 启动模型服务 +python rest_api/application.py 8891 +``` +Linux 用户推荐采用 Shell 脚本来启动服务:: + +```bash +sh examples/text_to_image/run_text_to_image.sh +``` + +#### 3.3.2 启动 WebUI + +WebUI使用了[gradio前端](https://gradio.app/),首先需要安装gradio,运行命令如下: +``` +pip install gradio +``` +然后使用如下的命令启动: +```bash +# 配置模型服务地址 +export API_ENDPOINT=http://127.0.0.1:8891 +# 在指定端口 8502 启动 WebUI +python ui/webapp_text_to_image.py --serving_port 8502 +``` +Linux 用户推荐采用 Shell 脚本来启动服务:: + +```bash +sh examples/text_to_image/run_text_to_image_web.sh +``` + +到这里您就可以打开浏览器访问 http://127.0.0.1:8502 地址体验文生图系统服务了。 + +如果安装遇见问题可以查看[FAQ文档](../../FAQ.md) + +## Acknowledge + +我们借鉴了 Deepset.ai [Haystack](https://github.com/deepset-ai/haystack) 优秀的框架设计,在此对[Haystack](https://github.com/deepset-ai/haystack)作者及其开源社区表示感谢。 + +We learn form the excellent framework design of Deepset.ai [Haystack](https://github.com/deepset-ai/haystack), and we would like to express our thanks to the authors of Haystack and their open source community. diff --git a/pipelines/examples/text_to_image/run_text_to_image.sh b/pipelines/examples/text_to_image/run_text_to_image.sh new file mode 100644 index 000000000000..4a61f0b98e9e --- /dev/null +++ b/pipelines/examples/text_to_image/run_text_to_image.sh @@ -0,0 +1,19 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 指定文生图的Yaml配置文件 +unset http_proxy && unset https_proxy +export PIPELINE_YAML_PATH=rest_api/pipeline/text_to_image.yaml +# 使用端口号 8891 启动模型服务 +python rest_api/application.py 8891 \ No newline at end of file diff --git a/pipelines/examples/text_to_image/run_text_to_image_web.sh b/pipelines/examples/text_to_image/run_text_to_image_web.sh new file mode 100644 index 000000000000..05a59f7be69f --- /dev/null +++ b/pipelines/examples/text_to_image/run_text_to_image_web.sh @@ -0,0 +1,18 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# 配置模型服务地址 +export API_ENDPOINT=http://127.0.0.1:8891 +# 在指定端口 8502 启动 WebUI +python ui/webapp_text_to_image.py --serving_port 8502 \ No newline at end of file diff --git a/pipelines/examples/text_to_image/text_to_image_example.py b/pipelines/examples/text_to_image/text_to_image_example.py new file mode 100644 index 000000000000..8637b1a52aa4 --- /dev/null +++ b/pipelines/examples/text_to_image/text_to_image_example.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse + +import paddle +from pipelines.nodes import ErnieTextToImageGenerator +from pipelines import TextToImagePipeline + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--api_key", default=None, type=str, help="The API Key.") +parser.add_argument("--secret_key", default=None, type=str, help="The secret key.") +parser.add_argument("--prompt_text", default='宁静的小镇', type=str, help="The prompt_text.") +parser.add_argument("--output_dir", default='ernievilg_output', type=str, help="The output path.") +parser.add_argument("--style", default='探索无限', type=str, help="The style text.") +parser.add_argument("--size", default='1024*1024', + choices=['1024*1024', '1024*1536', '1536*1024'], help="Size of the generation images") +parser.add_argument("--topk", default=5, type=int, help="The top k images.") +args = parser.parse_args() +# yapf: enable + + +def text_to_image(): + erine_image_generator = ErnieTextToImageGenerator(ak=args.api_key, + sk=args.secret_key) + pipe = TextToImagePipeline(erine_image_generator) + prediction = pipe.run(query=args.prompt_text, + params={ + "TextToImageGenerator": { + "topk": args.topk, + "style": args.style, + "resolution": args.size, + "output_dir": args.output_dir + } + }) + pipe.save_to_yaml('text_to_image.yaml') + + +if __name__ == "__main__": + text_to_image() diff --git a/pipelines/pipelines/__init__.py b/pipelines/pipelines/__init__.py index 83dc75fcf6b2..ede1272c1316 100644 --- a/pipelines/pipelines/__init__.py +++ b/pipelines/pipelines/__init__.py @@ -39,7 +39,8 @@ from pipelines.pipelines import Pipeline from pipelines.pipelines.standard_pipelines import (BaseStandardPipeline, ExtractiveQAPipeline, - SemanticSearchPipeline) + SemanticSearchPipeline, + TextToImagePipeline) import pandas as pd diff --git a/pipelines/pipelines/nodes/__init__.py b/pipelines/pipelines/nodes/__init__.py index a4285acaaf47..a56fd1ccbcfd 100644 --- a/pipelines/pipelines/nodes/__init__.py +++ b/pipelines/pipelines/nodes/__init__.py @@ -29,3 +29,4 @@ from pipelines.nodes.ranker import BaseRanker, ErnieRanker from pipelines.nodes.reader import BaseReader, ErnieReader from pipelines.nodes.retriever import BaseRetriever, DensePassageRetriever +from pipelines.nodes.text_to_image_generator import ErnieTextToImageGenerator diff --git a/pipelines/pipelines/nodes/text_to_image_generator/__init__.py 
b/pipelines/pipelines/nodes/text_to_image_generator/__init__.py new file mode 100644 index 000000000000..579c485c01a0 --- /dev/null +++ b/pipelines/pipelines/nodes/text_to_image_generator/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pipelines.nodes.text_to_image_generator.text_to_image_generator import ErnieTextToImageGenerator diff --git a/pipelines/pipelines/nodes/text_to_image_generator/text_to_image_generator.py b/pipelines/pipelines/nodes/text_to_image_generator/text_to_image_generator.py new file mode 100644 index 000000000000..9a7bf2aa7389 --- /dev/null +++ b/pipelines/pipelines/nodes/text_to_image_generator/text_to_image_generator.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import requests +import hashlib +from io import BytesIO +from PIL import Image +from typing import List +from typing import Optional +from tqdm.auto import tqdm + +from pipelines.schema import Document +from pipelines.nodes.base import BaseComponent + + +class ErnieTextToImageGenerator(BaseComponent): + """ + ErnieTextToImageGenerator that uses a Ernie Vilg for text to image generation. + """ + + def __init__(self, ak=None, sk=None): + """ + :param ak: ak for applying token to request wenxin api. + :param sk: sk for applying token to request wenxin api. 
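+        An access token is requested from the wenxin API at construction time using these two keys.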
+ """ + if (ak is None or sk is None): + raise Exception( + "Please apply api_key and secret_key from https://wenxin.baidu.com/moduleApi/ernieVilg" + ) + self.ak = ak + self.sk = sk + self.token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' + self.token = self._apply_token(self.ak, self.sk) + + # save init parameters to enable export of component config as YAML + self.set_config( + ak=ak, + sk=sk, + ) + + def _apply_token(self, ak, sk): + if ak is None or sk is None: + ak = self.ak + sk = self.sk + response = requests.get(self.token_host, + params={ + 'grant_type': 'client_credentials', + 'client_id': ak, + 'client_secret': sk + }) + if response: + res = response.json() + if res['code'] != 0: + print('Request access token error.') + raise RuntimeError("Request access token error.") + else: + print('Request access token error.') + raise RuntimeError("Request access token error.") + return res['data'] + + def generate_image(self, + text_prompts, + style: Optional[str] = "探索无限", + resolution: Optional[str] = "1024*1024", + topk: Optional[int] = 6, + visualization: Optional[bool] = True, + output_dir: Optional[str] = 'ernievilg_output'): + """ + Create image by text prompts using ErnieVilG model. + :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. + :param style: Image stype, currently supported 古风、油画、水彩、卡通、二次元、浮世绘、蒸汽波艺术、 + low poly、像素风格、概念艺术、未来主义、赛博朋克、写实风格、洛丽塔风格、巴洛克风格、超现实主义、探索无限。 + :param resolution: Resolution of images, currently supported "1024*1024", "1024*1536", "1536*1024". + :param topk: Top k images to save. + :param visualization: Whether to save images or not. + :output_dir: Output directory + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + token = self.token + create_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/txt2img?from=paddlehub' + get_url = 'https://wenxin.baidu.com/younger/portal/api/rest/1.0/ernievilg/v1/getImg?from=paddlehub' + if isinstance(text_prompts, str): + text_prompts = [text_prompts] + taskids = [] + for text_prompt in text_prompts: + res = requests.post( + create_url, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + data={ + 'access_token': token, + "text": text_prompt, + "style": style, + "resolution": resolution + }) + res = res.json() + if res['code'] == 4001: + print('请求参数错误') + raise RuntimeError("请求参数错误") + elif res['code'] == 4002: + print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') + raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") + elif res['code'] == 4003: + print('请求参数中,图片风格不在可选范围内') + raise RuntimeError("请求参数中,图片风格不在可选范围内") + elif res['code'] == 4004: + print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') + raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") + elif res['code'] == 100 or res['code'] == 110 or res['code'] == 111: + token = self._apply_token(self.ak, self.sk) + res = requests.post(create_url, + headers={ + 'Content-Type': + 'application/x-www-form-urlencoded' + }, + data={ + 'access_token': token, + "text": text_prompt, + "style": style, + "resolution": resolution + }) + res = res.json() + if res['code'] != 0: + print("Token失效重新请求后依然发生错误,请检查输入的参数") + raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") + if res['msg'] == 'success': + taskids.append(res['data']["taskId"]) + else: + print(res['msg']) + raise RuntimeError(res['msg']) + + start_time = time.time() + process_bar = tqdm(total=100, unit='%') + results = {} + total_time = 60 * len(taskids) + while True: + end_time = time.time() + duration = 
end_time - start_time + progress_rate = int((duration) / total_time * 100) + if not taskids: + progress_rate = 100 + if progress_rate > process_bar.n: + if progress_rate >= 100: + if not taskids: + increase_rate = 100 - process_bar.n + else: + increase_rate = 0 + else: + increase_rate = progress_rate - process_bar.n + else: + increase_rate = 0 + process_bar.update(increase_rate) + if duration < 30: + time.sleep(5) + continue + else: + time.sleep(6) + if not taskids: + break + has_done = [] + for taskid in taskids: + res = requests.post(get_url, + headers={ + 'Content-Type': + 'application/x-www-form-urlencoded' + }, + data={ + 'access_token': token, + 'taskId': {taskid} + }) + res = res.json() + if res['code'] == 4001: + print('请求参数错误') + raise RuntimeError("请求参数错误") + elif res['code'] == 4002: + print('请求参数格式错误,请检查必传参数是否齐全,参数类型等') + raise RuntimeError("请求参数格式错误,请检查必传参数是否齐全,参数类型等") + elif res['code'] == 4003: + print('请求参数中,图片风格不在可选范围内') + raise RuntimeError("请求参数中,图片风格不在可选范围内") + elif res['code'] == 4004: + print('API服务内部错误,可能引起原因有请求超时、模型推理错误等') + raise RuntimeError("API服务内部错误,可能引起原因有请求超时、模型推理错误等") + elif res['code'] == 100 or res['code'] == 110 or res[ + 'code'] == 111: + token = self._apply_token(self.ak, self.sk) + res = requests.post(get_url, + headers={ + 'Content-Type': + 'application/x-www-form-urlencoded' + }, + data={ + 'access_token': token, + 'taskId': {taskid} + }) + res = res.json() + if res['code'] != 0: + print("Token失效重新请求后依然发生错误,请检查输入的参数") + raise RuntimeError("Token失效重新请求后依然发生错误,请检查输入的参数") + if res['msg'] == 'success': + if res['data']['status'] == 1: + has_done.append(res['data']['taskId']) + results[res['data']['text']] = { + 'imgUrls': res['data']['imgUrls'], + 'waiting': res['data']['waiting'], + 'taskId': res['data']['taskId'] + } + else: + print(res['msg']) + raise RuntimeError(res['msg']) + for taskid in has_done: + taskids.remove(taskid) + print('Saving Images...') + result_images = [] + for text, data in results.items(): + for idx, imgdata in enumerate(data['imgUrls']): + try: + image = Image.open( + BytesIO(requests.get(imgdata['image']).content)) + except Exception as e: + print('Download generated images error, retry one time') + try: + image = Image.open( + BytesIO(requests.get(imgdata['image']).content)) + except Exception: + raise RuntimeError('Download generated images failed.') + if visualization: + ext = 'png' + md5hash = hashlib.md5(image.tobytes()) + md5_name = md5hash.hexdigest() + image_name = '{}.{}'.format(md5_name, ext) + image_path = os.path.join(output_dir, image_name) + image.save(image_path) + result_images.append(image_path) + if idx + 1 >= topk: + break + print('Done') + return result_images + + def run(self, + query: Document, + style: Optional[str] = None, + topk: Optional[int] = None, + resolution: Optional[str] = "1024*1024", + output_dir: Optional[str] = 'ernievilg_output'): + + result_images = self.generate_image(query, + style=style, + topk=topk, + resolution=resolution, + output_dir=output_dir) + results = {"results": result_images} + return results, "output_1" diff --git a/pipelines/pipelines/pipelines/__init__.py b/pipelines/pipelines/pipelines/__init__.py index a1a25f53aa4b..04f1367033a6 100644 --- a/pipelines/pipelines/pipelines/__init__.py +++ b/pipelines/pipelines/pipelines/__init__.py @@ -15,4 +15,5 @@ from pipelines.pipelines.base import Pipeline, RootNode from pipelines.pipelines.standard_pipelines import (BaseStandardPipeline, ExtractiveQAPipeline, - SemanticSearchPipeline) + SemanticSearchPipeline, + TextToImagePipeline) 
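For reference, once the API Key and Secret Key have been filled into `rest_api/pipeline/text_to_image.yaml`, the exported pipeline can also be rebuilt and queried directly in Python without going through the REST layer. The snippet below is a minimal sketch: it assumes the Haystack-style `Pipeline.load_from_yaml` loader available in `pipelines`, and the node name and parameters mirror `text_to_image_example.py` and `text_to_image.yaml`.

```python
from pathlib import Path

from pipelines import Pipeline

# Rebuild the pipeline declared in rest_api/pipeline/text_to_image.yaml
# (the pipeline is named "query" in that file); ak/sk must already be filled in.
pipeline = Pipeline.load_from_yaml(Path("rest_api/pipeline/text_to_image.yaml"),
                                   pipeline_name="query")

# Run a prompt through the generator node, mirroring text_to_image_example.py.
result = pipeline.run(query="宁静的小镇",
                      params={
                          "TextToImageGenerator": {
                              "topk": 5,
                              "style": "古风",
                              "resolution": "1024*1024",
                              "output_dir": "ernievilg_output",
                          }
                      })

# The generator returns the local paths of the downloaded images under "results".
print(result["results"])
```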
diff --git a/pipelines/pipelines/pipelines/standard_pipelines.py b/pipelines/pipelines/pipelines/standard_pipelines.py index d459c33db7c2..2b2fc5cbe769 100644 --- a/pipelines/pipelines/pipelines/standard_pipelines.py +++ b/pipelines/pipelines/pipelines/standard_pipelines.py @@ -24,6 +24,7 @@ from pipelines.nodes.ranker import BaseRanker from pipelines.nodes.retriever import BaseRetriever from pipelines.document_stores import BaseDocumentStore +from pipelines.nodes.text_to_image_generator import ErnieTextToImageGenerator from pipelines.pipelines import Pipeline logger = logging.getLogger(__name__) @@ -263,3 +264,34 @@ def run(self, """ output = self.pipeline.run(query=query, params=params, debug=debug) return output + + +class TextToImagePipeline(BaseStandardPipeline): + """ + A simple pipeline that takes prompt texts as input and generates + images. + """ + + def __init__(self, text_to_image_generator: ErnieTextToImageGenerator): + self.pipeline = Pipeline() + self.pipeline.add_node(component=text_to_image_generator, + name="TextToImageGenerator", + inputs=["Query"]) + + def run(self, + query: str, + params: Optional[dict] = None, + debug: Optional[bool] = None): + output = self.pipeline.run(query=query, params=params, debug=debug) + return output + + def run_batch( + self, + documents: List[Document], + params: Optional[dict] = None, + debug: Optional[bool] = None, + ): + output = self.pipeline.run_batch(documents=documents, + params=params, + debug=debug) + return output diff --git a/pipelines/rest_api/controller/search.py b/pipelines/rest_api/controller/search.py index 29c7359e608b..780137440225 100644 --- a/pipelines/rest_api/controller/search.py +++ b/pipelines/rest_api/controller/search.py @@ -27,7 +27,7 @@ from pipelines.pipelines.base import Pipeline from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME from rest_api.config import LOG_LEVEL, CONCURRENT_REQUEST_PER_WORKER -from rest_api.schema import QueryRequest, QueryResponse +from rest_api.schema import QueryRequest, QueryResponse, QueryImageResponse from rest_api.controller.utils import RequestLimiter logging.getLogger("pipelines").setLevel(LOG_LEVEL) @@ -81,6 +81,27 @@ def query(request: QueryRequest): return result +@router.post("/query_text_to_images", + response_model=QueryImageResponse, + response_model_exclude_none=True) +def query_images(request: QueryRequest): + """ + This endpoint receives the question as a string and allows the requester to set + additional parameters that will be passed on to the pipelines pipeline. 
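+    Unlike the document search endpoints, the `answers` field of the response holds the file paths of the generated images.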
+ """ + result = {} + result['query'] = request.query + params = request.params or {} + res = PIPELINE.run(query=request.query, params=params, debug=request.debug) + # Ensure answers and documents exist, even if they're empty lists + result['answers'] = res['results'] + if not "documents" in result: + result["documents"] = [] + if not "answers" in result: + result["answers"] = [] + return result + + def _process_request(pipeline, request) -> Dict[str, Any]: start_time = time.time() diff --git a/pipelines/rest_api/pipeline/text_to_image.yaml b/pipelines/rest_api/pipeline/text_to_image.yaml new file mode 100644 index 000000000000..959781b2f225 --- /dev/null +++ b/pipelines/rest_api/pipeline/text_to_image.yaml @@ -0,0 +1,16 @@ +version: '1.1.0' + +components: + - name: TextToImageGenerator + params: + ak: + sk: + type: ErnieTextToImageGenerator +pipelines: + - name: query + type: Query + nodes: + - name: TextToImageGenerator + inputs: [Query] + + diff --git a/pipelines/rest_api/schema.py b/pipelines/rest_api/schema.py index 942a4e7029ef..e041d2bad62e 100644 --- a/pipelines/rest_api/schema.py +++ b/pipelines/rest_api/schema.py @@ -83,3 +83,10 @@ class QueryResponse(BaseModel): answers: List[AnswerSerialized] = [] documents: List[DocumentSerialized] = [] debug: Optional[Dict] = Field(None, alias="_debug") + + +class QueryImageResponse(BaseModel): + query: str + answers: List[str] = [] + documents: List[DocumentSerialized] = [] + debug: Optional[Dict] = Field(None, alias="_debug") diff --git a/pipelines/ui/utils.py b/pipelines/ui/utils.py index 1d672613dbb1..540c44cc2247 100644 --- a/pipelines/ui/utils.py +++ b/pipelines/ui/utils.py @@ -31,6 +31,7 @@ DOC_FEEDBACK = "feedback" DOC_UPLOAD = "file-upload" DOC_PARSE = 'files' +IMAGE_REQUEST = 'query_text_to_images' def pipelines_is_ready(): @@ -184,6 +185,35 @@ def semantic_search( return results, response +def text_to_image_search( + query, + resolution="1024*1024", + top_k_images=5, + style="探索无限") -> Tuple[List[Dict[str, Any]], Dict[str, str]]: + """ + Send a prompt text and corresponding parameters to the REST API + """ + url = f"{API_ENDPOINT}/{IMAGE_REQUEST}" + params = { + "TextToImageGenerator": { + "style": style, + "topk": top_k_images, + "resolution": resolution, + } + } + req = {"query": query, "params": params} + response_raw = requests.post(url, json=req) + + if response_raw.status_code >= 400 and response_raw.status_code != 503: + raise Exception(f"{vars(response_raw)}") + + response = response_raw.json() + if "errors" in response: + raise Exception(", ".join(response["errors"])) + results = response["answers"] + return results, response + + def send_feedback(query, answer_obj, is_correct_answer, is_correct_document, document) -> None: """ diff --git a/pipelines/ui/webapp_text_to_image.py b/pipelines/ui/webapp_text_to_image.py new file mode 100644 index 000000000000..60bd521b5e34 --- /dev/null +++ b/pipelines/ui/webapp_text_to_image.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import os +import argparse + +from PIL import Image +from utils import text_to_image_search +import gradio as gr + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--serving_port", default=8502, type=int, help="Port for the serving.") +args = parser.parse_args() +# yapf: enable + + +def infer(text_prompt, top_k_images, Size, style): + results, raw_json = text_to_image_search(text_prompt, + resolution=Size, + top_k_images=top_k_images, + style=style) + return results + + +def main(): + block = gr.Blocks() + + with block: + with gr.Group(): + with gr.Box(): + with gr.Row().style(mobile_collapse=False, equal_height=True): + text_prompt = gr.Textbox( + label="Enter your prompt", + value='宁静的小镇', + show_label=False, + max_lines=1, + placeholder="Enter your prompt", + ).style( + border=(True, False, True, True), + rounded=(True, False, False, True), + container=False, + ) + btn = gr.Button("开始生成").style( + margin=False, + rounded=(False, True, True, False), + ) + gallery = gr.Gallery(label="Generated images", + show_label=False, + elem_id="gallery").style(grid=[2], + height="auto") + + advanced_button = gr.Button("Advanced options", + elem_id="advanced-btn") + + with gr.Row(elem_id="advanced-options"): + top_k_images = gr.Slider(label="Images", + minimum=1, + maximum=50, + value=5, + step=1) + style = gr.Radio(label='Style', + value='古风', + choices=[ + '古风', '油画', '卡通画', '二次元', "水彩画", "浮世绘", + "蒸汽波艺术", "low poly", "像素风格", "概念艺术", + "未来主义", "赛博朋克", "写实风格", "洛丽塔风格", "巴洛克风格", + "超现实主义", "探索无限" + ]) + Size = gr.Radio(label='Size', + value='1024*1024', + choices=['1024*1024', '1024*1536', '1536*1024']) + + text_prompt.submit(infer, + inputs=[text_prompt, top_k_images, Size, style], + outputs=gallery) + btn.click(infer, + inputs=[text_prompt, top_k_images, Size, style], + outputs=gallery) + advanced_button.click( + None, + [], + text_prompt, + ) + return block + + +if __name__ == "__main__": + block = main() + block.launch(server_name='0.0.0.0', + server_port=args.serving_port, + share=False)
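Finally, once the model service started by `run_text_to_image.sh` is listening (port 8891 as in the README), the new `/query_text_to_images` endpoint can also be exercised without the Gradio UI. A minimal client sketch, mirroring the request that `text_to_image_search` in `ui/utils.py` builds:

```python
import requests

# Assumes the REST service from run_text_to_image.sh is listening on port 8891.
API_ENDPOINT = "http://127.0.0.1:8891"

payload = {
    "query": "宁静的小镇",
    "params": {
        "TextToImageGenerator": {
            "style": "古风",
            "topk": 5,
            "resolution": "1024*1024",
        }
    },
}

response = requests.post(f"{API_ENDPOINT}/query_text_to_images", json=payload)
response.raise_for_status()

# QueryImageResponse.answers holds the paths of the generated images on the server.
for image_path in response.json()["answers"]:
    print(image_path)
```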