Improve download of generated images, serve images in the api (#2391)
* Improve download of generated images, serve images in the api
Add support for conversation handling in the api

* Add original prompt to image response

* Add download images option in gui, fix loading model list in Airforce

hlohaus authored Nov 20, 2024
1 parent c959d9b commit ffb4b0d
Showing 29 changed files with 494 additions and 328 deletions.
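The client-visible surface of this commit: the chat-completions endpoint now accepts a client-generated `conversation_id` so follow-up requests share context, image responses carry the original prompt, and generated images are served through the API itself. A minimal sketch of a single non-streaming request (the endpoint and body fields come from etc/examples/api.py below; treating the reply as standard OpenAI-style JSON is an assumption):

import requests

url = "http://localhost:1337/v1/chat/completions"
body = {
    "model": "",
    "provider": "Copilot",
    "stream": False,
    "messages": [{"role": "user", "content": "Hello"}],
}
response = requests.post(url, json=body)
response.raise_for_status()
# Assumed OpenAI-compatible shape; the streaming variant is shown in the example file below.
print(response.json()["choices"][0]["message"]["content"])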
35 changes: 31 additions & 4 deletions etc/examples/api.py
@@ -1,13 +1,17 @@
 import requests
 import json
+import uuid
+
 url = "http://localhost:1337/v1/chat/completions"
+conversation_id = str(uuid.uuid4())
 body = {
     "model": "",
-    "provider": "",
+    "provider": "Copilot",
     "stream": True,
     "messages": [
-        {"role": "user", "content": "What can you do? Who are you?"}
-    ]
+        {"role": "user", "content": "Hello, i am Heiner. How are you?"}
+    ],
+    "conversation_id": conversation_id
 }
 response = requests.post(url, json=body, stream=True)
 response.raise_for_status()
@@ -21,4 +25,27 @@
             print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
         except json.JSONDecodeError:
             pass
-print()
+print()
+print()
+print()
+body = {
+    "model": "",
+    "provider": "Copilot",
+    "stream": True,
+    "messages": [
+        {"role": "user", "content": "Tell me somethings about my name"}
+    ],
+    "conversation_id": conversation_id
+}
+response = requests.post(url, json=body, stream=True)
+response.raise_for_status()
+for line in response.iter_lines():
+    if line.startswith(b"data: "):
+        try:
+            json_data = json.loads(line[6:])
+            if json_data.get("error"):
+                print(json_data)
+                break
+            print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
+        except json.JSONDecodeError:
+            pass
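The example above runs the same streaming loop twice; a helper along these lines (hypothetical, not part of the commit) captures the pattern once:

def stream_chat(body: dict) -> None:
    # Post a chat body and print streamed delta content, mirroring the loops above.
    response = requests.post(url, json=body, stream=True)
    response.raise_for_status()
    for line in response.iter_lines():
        if line.startswith(b"data: "):
            try:
                json_data = json.loads(line[6:])
                if json_data.get("error"):
                    print(json_data)
                    break
                print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
            except json.JSONDecodeError:
                pass
    print()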
24 changes: 12 additions & 12 deletions g4f/Provider/Airforce.py
@@ -20,7 +20,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
     working = True
     supports_system_message = True
     supports_message_history = True
-    
+
     @classmethod
     def fetch_completions_models(cls):
         response = requests.get('https://api.airforce/models', verify=False)
@@ -34,19 +34,20 @@ def fetch_imagine_models(cls):
         response.raise_for_status()
         return response.json()

-    completions_models = fetch_completions_models.__func__(None)
-    imagine_models = fetch_imagine_models.__func__(None)
-
     default_model = "gpt-4o-mini"
     default_image_model = "flux"
     additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"]
-    text_models = completions_models
-    image_models = [*imagine_models, *additional_models_imagine]
-    models = [
-        *text_models,
-        *image_models,
-    ]
-
+
+    @classmethod
+    def get_models(cls):
+        if not cls.models:
+            cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine]
+            cls.models = [
+                *cls.fetch_completions_models(),
+                *cls.image_models
+            ]
+        return cls.models
+
     model_aliases = {
         ### completions ###
         # openchat
@@ -100,7 +101,6 @@ def create_async_generator(
         **kwargs
     ) -> AsyncResult:
         model = cls.get_model(model)
-
         if model in cls.image_models:
             return cls._generate_image(model, messages, proxy, seed, size)
         else:
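The Airforce diff above moves model discovery out of the class body, where fetch_completions_models.__func__(None) issued HTTP requests at import time, into a get_models() classmethod that fetches once on first use. Distilled to a standalone sketch (class name and fetch stub are hypothetical):

class LazyModels:
    models: list = []

    @classmethod
    def get_models(cls) -> list:
        if not cls.models:                   # populated on first call only
            cls.models = cls.fetch_models()  # network I/O no longer runs at import time
        return cls.models

    @classmethod
    def fetch_models(cls) -> list:
        return ["gpt-4o-mini", "flux"]       # stand-in for the two HTTP endpoints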
44 changes: 21 additions & 23 deletions g4f/Provider/Blackbox.py
@@ -20,15 +20,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
     supports_system_message = True
     supports_message_history = True
     _last_validated_value = None
-    
+
     default_model = 'blackboxai'
     default_vision_model = default_model
     default_image_model = 'Image Generation'
     image_models = ['Image Generation', 'repomap']
     vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'gemini-1.5-flash', 'llama-3.1-8b', 'llama-3.1-70b', 'llama-3.1-405b']
-    
+
     userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
-    
+
     agentMode = {
         'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
     }
@@ -77,22 +77,21 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
     }

     additional_prefixes = {
-        'gpt-4o': '@gpt-4o',
-        'gemini-pro': '@gemini-pro',
-        'claude-sonnet-3.5': '@claude-sonnet'
-        }
+        'gpt-4o': '@gpt-4o',
+        'gemini-pro': '@gemini-pro',
+        'claude-sonnet-3.5': '@claude-sonnet'
+    }

     model_prefixes = {
-        **{mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
-        if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]},
+        **{
+            mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
+            if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]
+        },
         **additional_prefixes
     }

-
     models = list(dict.fromkeys([default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())]))

-
-
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
         "claude-3.5-sonnet": "claude-sonnet-3.5",
@@ -131,12 +130,11 @@ async def fetch_validated(cls):

         return cls._last_validated_value

-
     @staticmethod
     def generate_id(length=7):
         characters = string.ascii_letters + string.digits
         return ''.join(random.choice(characters) for _ in range(length))

     @classmethod
     def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages:
         prefix = cls.model_prefixes.get(model, "")
@@ -157,6 +155,7 @@ async def create_async_generator(
         cls,
         model: str,
         messages: Messages,
+        prompt: str = None,
         proxy: str = None,
         web_search: bool = False,
         image: ImageType = None,
@@ -191,7 +190,7 @@ async def create_async_generator(
             'sec-fetch-site': 'same-origin',
             'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
         }
-        
+
         data = {
             "messages": messages,
             "id": message_id,
@@ -221,26 +220,25 @@ async def create_async_generator(
         async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
             response.raise_for_status()
             response_text = await response.text()

             if model in cls.image_models:
                 image_matches = re.findall(r'!\[.*?\]\((https?://[^\)]+)\)', response_text)
                 if image_matches:
                     image_url = image_matches[0]
-                    image_response = ImageResponse(images=[image_url], alt="Generated Image")
-                    yield image_response
+                    yield ImageResponse(image_url, prompt)
                 return

             response_text = re.sub(r'Generated by BLACKBOX.AI, try unlimited chat https://www.blackbox.ai', '', response_text, flags=re.DOTALL)

             json_match = re.search(r'\$~~~\$(.*?)\$~~~\$', response_text, re.DOTALL)
             if json_match:
                 search_results = json.loads(json_match.group(1))
                 answer = response_text.split('$~~~$')[-1].strip()

                 formatted_response = f"{answer}\n\n**Source:**"
                 for i, result in enumerate(search_results, 1):
                     formatted_response += f"\n{i}. {result['title']}: {result['link']}"

                 yield formatted_response
             else:
                 yield response_text.strip()
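In the Blackbox diff above, the image branch now yields ImageResponse(image_url, prompt), so the caller gets the original prompt rather than a fixed alt text. The URL itself is pulled from the provider's markdown output with the regex shown in the diff; a self-contained sketch:

import re

def extract_image_url(text: str):
    # Regex copied from the diff: captures the URL of the first markdown image.
    matches = re.findall(r'!\[.*?\]\((https?://[^\)]+)\)', text)
    return matches[0] if matches else None

assert extract_image_url("![img](https://example.com/a.png)") == "https://example.com/a.png"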
9 changes: 8 additions & 1 deletion g4f/Provider/Copilot.py
@@ -57,6 +57,7 @@ def create_completion(
         image: ImageType = None,
         conversation: Conversation = None,
         return_conversation: bool = False,
+        web_search: bool = True,
         **kwargs
     ) -> CreateResult:
         if not has_curl_cffi:
@@ -124,12 +125,14 @@ def create_completion(
             is_started = False
             msg = None
             image_prompt: str = None
+            last_msg = None
             while True:
                 try:
                     msg = wss.recv()[0]
                     msg = json.loads(msg)
                 except:
                     break
+                last_msg = msg
                 if msg.get("event") == "appendText":
                     is_started = True
                     yield msg.get("text")
@@ -139,8 +142,12 @@ def create_completion(
                     yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
                 elif msg.get("event") == "done":
                     break
+                elif msg.get("event") == "error":
+                    raise RuntimeError(f"Error: {msg}")
+                elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:
+                    debug.log(f"Copilot Message: {msg}")
             if not is_started:
-                raise RuntimeError(f"Last message: {msg}")
+                raise RuntimeError(f"Invalid response: {last_msg}")

     @classmethod
     async def get_access_token_and_cookies(cls, proxy: str = None):
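Copilot's websocket loop above now remembers the last received message, raises on "error" events, and logs unrecognized events instead of silently ignoring them. The dispatch logic, reduced to a synchronous sketch over plain dicts (hypothetical helper, not part of the commit):

def drain_events(events) -> str:
    # Collect appendText chunks, raise on "error", keep the last message for diagnostics.
    is_started, last_msg, chunks = False, None, []
    for msg in events:
        last_msg = msg
        event = msg.get("event")
        if event == "appendText":
            is_started = True
            chunks.append(msg.get("text", ""))
        elif event == "done":
            break
        elif event == "error":
            raise RuntimeError(f"Error: {msg}")
    if not is_started:
        raise RuntimeError(f"Invalid response: {last_msg}")
    return "".join(chunks)

print(drain_events([{"event": "appendText", "text": "Hi"}, {"event": "done"}]))  # Hi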
7 changes: 4 additions & 3 deletions g4f/Provider/PollinationsAI.py
@@ -3,7 +3,6 @@
 from urllib.parse import quote
 import random
 import requests
-from sys import maxsize
 from aiohttp import ClientSession

 from ..typing import AsyncResult, Messages
@@ -40,6 +39,7 @@ async def create_async_generator(
         cls,
         model: str,
         messages: Messages,
+        prompt: str = None,
         api_base: str = "https://text.pollinations.ai/openai",
         api_key: str = None,
         proxy: str = None,
@@ -49,9 +49,10 @@
         if model:
             model = cls.get_model(model)
             if model in cls.image_models:
-                prompt = messages[-1]["content"]
+                if prompt is None:
+                    prompt = messages[-1]["content"]
                 if seed is None:
-                    seed = random.randint(0, maxsize)
+                    seed = random.randint(0, 100000)
                 image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}"
                 yield ImageResponse(image, prompt)
                 return
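PollinationsAI above now honors a caller-supplied prompt, falls back to the last message only when none is given, and draws seeds from 0-100000 instead of 0-sys.maxsize. The URL construction, lifted into a standalone sketch (the helper name is hypothetical; the query parameters are verbatim from the diff):

import random
from urllib.parse import quote

def pollinations_image_url(prompt: str, model: str, seed: int = None) -> str:
    # Build the same image URL as the provider; random seed when none is supplied.
    if seed is None:
        seed = random.randint(0, 100000)
    return (
        f"https://image.pollinations.ai/prompt/{quote(prompt)}"
        f"?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}"
    )

print(pollinations_image_url("a red fox", "flux"))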
46 changes: 19 additions & 27 deletions g4f/Provider/ReplicateHome.py
@@ -6,6 +6,8 @@

 from ..typing import AsyncResult, Messages
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..requests.aiohttp import get_connector
+from ..requests.raise_for_status import raise_for_status
 from .helper import format_prompt
 from ..image import ImageResponse

@@ -32,10 +34,8 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
         'yorickvp/llava-13b',
     ]

-
-
     models = text_models + image_models

     model_aliases = {
         # image_models
         "sd-3": "stability-ai/stable-diffusion-3",
@@ -56,23 +56,14 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
         # text_models
         "google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626",
         "yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
-
     }

-    @classmethod
-    def get_model(cls, model: str) -> str:
-        if model in cls.models:
-            return model
-        elif model in cls.model_aliases:
-            return cls.model_aliases[model]
-        else:
-            return cls.default_model
-
     @classmethod
     async def create_async_generator(
         cls,
         model: str,
         messages: Messages,
+        prompt: str = None,
         proxy: str = None,
         **kwargs
     ) -> AsyncResult:
@@ -96,29 +87,30 @@ async def create_async_generator(
             "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"
         }

-        async with ClientSession(headers=headers) as session:
-            if model in cls.image_models:
-                prompt = messages[-1]['content'] if messages else ""
-            else:
-                prompt = format_prompt(messages)
-
+        async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session:
+            if prompt is None:
+                if model in cls.image_models:
+                    prompt = messages[-1]['content']
+                else:
+                    prompt = format_prompt(messages)
+
             data = {
                 "model": model,
                 "version": cls.model_versions[model],
                 "input": {"prompt": prompt},
             }
-            async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
-                response.raise_for_status()

+            async with session.post(cls.api_endpoint, json=data) as response:
+                await raise_for_status(response)
                 result = await response.json()
                 prediction_id = result['id']

                 poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}"
                 max_attempts = 30
                 delay = 5
                 for _ in range(max_attempts):
-                    async with session.get(poll_url, proxy=proxy) as response:
-                        response.raise_for_status()
+                    async with session.get(poll_url) as response:
+                        await raise_for_status(response)
                         try:
                             result = await response.json()
                         except ContentTypeError:
@@ -131,7 +123,7 @@ async def create_async_generator(
                         if result['status'] == 'succeeded':
                             if model in cls.image_models:
                                 image_url = result['output'][0]
-                                yield ImageResponse(image_url, "Generated image")
+                                yield ImageResponse(image_url, prompt)
                                 return
                             else:
                                 for chunk in result['output']:
@@ -140,6 +132,6 @@ async def create_async_generator(
                         elif result['status'] == 'failed':
                             raise Exception(f"Prediction failed: {result.get('error')}")
                     await asyncio.sleep(delay)
-        
+
         if result['status'] != 'succeeded':
             raise Exception("Prediction timed out")
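ReplicateHome above attaches the proxy once via get_connector(proxy=proxy) on the session rather than passing proxy to every call, uses the library's raise_for_status() helper, and yields ImageResponse with the real prompt. The polling loop, extracted as a hypothetical standalone coroutine:

import asyncio

async def poll_prediction(session, prediction_id: str, max_attempts: int = 30, delay: int = 5) -> dict:
    # Same endpoint and timing as the diff above; the JSON-decode fallback is omitted.
    poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}"
    for _ in range(max_attempts):
        async with session.get(poll_url) as response:
            result = await response.json()
        if result['status'] == 'succeeded':
            return result
        if result['status'] == 'failed':
            raise Exception(f"Prediction failed: {result.get('error')}")
        await asyncio.sleep(delay)
    raise Exception("Prediction timed out")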