Skip to content

Commit

Permalink
Fix image generation in OpenaiChat (#2390)
Browse files Browse the repository at this point in the history
* Fix image generation in OpenaiChat

* Add PollinationsAI provider with image and text generation
  • Loading branch information
hlohaus authored Nov 20, 2024
1 parent ea34697 commit dba41cd
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 52 deletions.
69 changes: 69 additions & 0 deletions g4f/Provider/PollinationsAI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

from urllib.parse import quote
import random
import requests
from sys import maxsize
from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from .needs_auth.OpenaiAPI import OpenaiAPI
from .helper import format_prompt

class PollinationsAI(OpenaiAPI):
label = "Pollinations.AI"
url = "https://pollinations.ai"
working = True
supports_stream = True
default_model = "openai"

@classmethod
def get_models(cls):
if not cls.image_models:
url = "https://image.pollinations.ai/models"
response = requests.get(url)
raise_for_status(response)
cls.image_models = response.json()
if not cls.models:
url = "https://text.pollinations.ai/models"
response = requests.get(url)
raise_for_status(response)
cls.models = [model.get("name") for model in response.json()]
cls.models.extend(cls.image_models)
return cls.models

@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://text.pollinations.ai/openai",
api_key: str = None,
proxy: str = None,
seed: str = None,
**kwargs
) -> AsyncResult:
if model:
model = cls.get_model(model)
if model in cls.image_models:
prompt = messages[-1]["content"]
if seed is None:
seed = random.randint(0, maxsize)
image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}"
yield ImageResponse(image, prompt)
return
if api_key is None:
async with ClientSession(connector=get_connector(proxy=proxy)) as session:
prompt = format_prompt(messages)
async with session.get(f"https://text.pollinations.ai/{quote(prompt)}?model={quote(model)}") as response:
await raise_for_status(response)
async for line in response.content.iter_any():
yield line.decode(errors="ignore")
else:
async for chunk in super().create_async_generator(
model, messages, api_base=api_base, proxy=proxy, **kwargs
):
yield chunk
1 change: 1 addition & 0 deletions g4f/Provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .PerplexityLabs import PerplexityLabs
from .Pi import Pi
from .Pizzagpt import Pizzagpt
from .PollinationsAI import PollinationsAI
from .Prodia import Prodia
from .Reka import Reka
from .ReplicateHome import ReplicateHome
Expand Down
65 changes: 13 additions & 52 deletions g4f/Provider/needs_auth/OpenaiChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
default_vision_model = "gpt-4o"
fallback_models = ["auto", "gpt-4", "gpt-4o", "gpt-4o-mini", "gpt-4o-canmore", "o1-preview", "o1-mini"]
vision_models = fallback_models
image_models = fallback_models

_api_key: str = None
_headers: dict = None
Expand Down Expand Up @@ -330,7 +331,7 @@ async def create_async_generator(
api_key: str = None,
cookies: Cookies = None,
auto_continue: bool = False,
history_disabled: bool = True,
history_disabled: bool = False,
action: str = "next",
conversation_id: str = None,
conversation: Conversation = None,
Expand Down Expand Up @@ -425,12 +426,6 @@ async def create_async_generator(
f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}",
f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
)]
ws = None
if need_arkose:
async with session.post(f"{cls.url}/backend-api/register-websocket", headers=cls._headers) as response:
wss_url = (await response.json()).get("wss_url")
if wss_url:
ws = await session.ws_connect(wss_url)
data = {
"action": action,
"messages": None,
Expand Down Expand Up @@ -474,7 +469,7 @@ async def create_async_generator(
await asyncio.sleep(5)
continue
await raise_for_status(response)
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation, ws):
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, conversation):
if return_conversation:
history_disabled = False
return_conversation = False
Expand All @@ -489,44 +484,16 @@ async def create_async_generator(
if history_disabled and auto_continue:
await cls.delete_conversation(session, cls._headers, conversation.conversation_id)

@staticmethod
async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str, is_curl: bool) -> AsyncIterator:
while True:
if is_curl:
message = json.loads(ws.recv()[0])
else:
message = await ws.receive_json()
if message["conversation_id"] == conversation_id:
yield base64.b64decode(message["body"])

@classmethod
async def iter_messages_chunk(
cls,
messages: AsyncIterator,
session: StreamSession,
fields: Conversation,
ws = None
) -> AsyncIterator:
async for message in messages:
if message.startswith(b'{"wss_url":'):
message = json.loads(message)
ws = await session.ws_connect(message["wss_url"]) if ws is None else ws
try:
async for chunk in cls.iter_messages_chunk(
cls.iter_messages_ws(ws, message["conversation_id"], hasattr(ws, "recv")),
session, fields
):
yield chunk
finally:
await ws.aclose() if hasattr(ws, "aclose") else await ws.close()
break
async for chunk in cls.iter_messages_line(session, message, fields):
if fields.finish_reason is not None:
break
else:
yield chunk
if fields.finish_reason is not None:
break
yield chunk

@classmethod
async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator:
Expand All @@ -542,9 +509,9 @@ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: C
return
if isinstance(line, dict) and "v" in line:
v = line.get("v")
if isinstance(v, str):
if isinstance(v, str) and fields.is_recipient:
yield v
elif isinstance(v, list):
elif isinstance(v, list) and fields.is_recipient:
for m in v:
if m.get("p") == "/message/content/parts/0":
yield m.get("v")
Expand All @@ -556,25 +523,20 @@ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: C
fields.conversation_id = v.get("conversation_id")
debug.log(f"OpenaiChat: New conversation: {fields.conversation_id}")
m = v.get("message", {})
if m.get("author", {}).get("role") == "assistant":
fields.message_id = v.get("message", {}).get("id")
fields.is_recipient = m.get("recipient") == "all"
if fields.is_recipient:
c = m.get("content", {})
if c.get("content_type") == "multimodal_text":
generated_images = []
for element in c.get("parts"):
if isinstance(element, str):
debug.log(f"No image or text: {line}")
elif element.get("content_type") == "image_asset_pointer":
if isinstance(element, dict) and element.get("content_type") == "image_asset_pointer":
generated_images.append(
cls.get_generated_image(session, cls._headers, element)
)
elif element.get("content_type") == "text":
for part in element.get("parts", []):
yield part
for image_response in await asyncio.gather(*generated_images):
yield image_response
else:
debug.log(f"OpenaiChat: {line}")
if m.get("author", {}).get("role") == "assistant":
fields.message_id = v.get("message", {}).get("id")
return
if "error" in line and line.get("error"):
raise RuntimeError(line.get("error"))
Expand Down Expand Up @@ -652,7 +614,7 @@ def _create_request_args(cls, cookies: Cookies = None, headers: dict = None, use
cls._headers = cls.get_default_headers() if headers is None else headers
if user_agent is not None:
cls._headers["user-agent"] = user_agent
cls._cookies = {} if cookies is None else {k: v for k, v in cookies.items() if k != "access_token"}
cls._cookies = {} if cookies is None else cookies
cls._update_cookie_header()

@classmethod
Expand All @@ -671,8 +633,6 @@ def _set_api_key(cls, api_key: str):
@classmethod
def _update_cookie_header(cls):
cls._headers["cookie"] = format_cookies(cls._cookies)
if "oai-did" in cls._cookies:
cls._headers["oai-device-id"] = cls._cookies["oai-did"]

class Conversation(BaseConversation):
"""
Expand All @@ -682,6 +642,7 @@ def __init__(self, conversation_id: str = None, message_id: str = None, finish_r
self.conversation_id = conversation_id
self.message_id = message_id
self.finish_reason = finish_reason
self.is_recipient = False

class Response():
"""
Expand Down
1 change: 1 addition & 0 deletions g4f/providers/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ class ProviderModelMixin:
default_model: str = None
models: list[str] = []
model_aliases: dict[str, str] = {}
image_models: list = None

@classmethod
def get_models(cls) -> list[str]:
Expand Down

0 comments on commit dba41cd

Please sign in to comment.