From e69269d844b7089dec636516d6edb4f70911ebf6 Mon Sep 17 00:00:00 2001 From: Nothing1024 Date: Tue, 21 Nov 2023 16:05:23 +0800 Subject: [PATCH 1/7] Feat: OpenAI Base URL supported --- backend/image_generation.py | 12 ++++++------ backend/llm.py | 4 ++-- backend/main.py | 15 ++++++++++++++- frontend/src/App.tsx | 1 + frontend/src/components/SettingsDialog.tsx | 19 +++++++++++++++++++ frontend/src/types.ts | 1 + 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/backend/image_generation.py b/backend/image_generation.py index 080334fe..ad217720 100644 --- a/backend/image_generation.py +++ b/backend/image_generation.py @@ -5,8 +5,8 @@ from bs4 import BeautifulSoup -async def process_tasks(prompts, api_key): - tasks = [generate_image(prompt, api_key) for prompt in prompts] +async def process_tasks(prompts, api_key, base_url): + tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts] results = await asyncio.gather(*tasks, return_exceptions=True) processed_results = [] @@ -20,8 +20,8 @@ async def process_tasks(prompts, api_key): return processed_results -async def generate_image(prompt, api_key): - client = AsyncOpenAI(api_key=api_key) +async def generate_image(prompt, api_key, base_url): + client = AsyncOpenAI(api_key=api_key, base_url=base_url) image_params = { "model": "dall-e-3", "quality": "standard", @@ -60,7 +60,7 @@ def create_alt_url_mapping(code): return mapping -async def generate_images(code, api_key, image_cache): +async def generate_images(code, api_key, base_url, image_cache): # Find all images soup = BeautifulSoup(code, "html.parser") images = soup.find_all("img") @@ -87,7 +87,7 @@ async def generate_images(code, api_key, image_cache): return code # Generate images - results = await process_tasks(prompts, api_key) + results = await process_tasks(prompts, api_key, base_url) # Create a dict mapping alt text to image URL mapped_image_urls = dict(zip(prompts, results)) diff --git a/backend/llm.py b/backend/llm.py index b52c3c9f..82a765e1 100644 --- a/backend/llm.py +++ b/backend/llm.py @@ -6,9 +6,9 @@ async def stream_openai_response( - messages, api_key: str, callback: Callable[[str], Awaitable[None]] + messages, api_key: str, base_url:str, callback: Callable[[str], Awaitable[None]] ): - client = AsyncOpenAI(api_key=api_key) + client = AsyncOpenAI(api_key=api_key, base_url=base_url) model = MODEL_GPT_4_VISION diff --git a/backend/main.py b/backend/main.py index 4eb09fa1..bd161c86 100644 --- a/backend/main.py +++ b/backend/main.py @@ -73,6 +73,13 @@ async def stream_code_test(websocket: WebSocket): openai_api_key = os.environ.get("OPENAI_API_KEY") if openai_api_key: print("Using OpenAI API key from environment variable") + if params["openAiBaseURL"]: + openai_base_url = params["openAiBaseURL"] + print("Using OpenAI Base URL from client-side settings dialog") + else: + openai_base_url = os.environ.get("OPENAI_BASE_URL") + if openai_base_url: + print("Using OpenAI Base URL from environment variable") if not openai_api_key: print("OpenAI API key not found") @@ -83,6 +90,11 @@ async def stream_code_test(websocket: WebSocket): } ) return + # openai_base_url="https://flag.smarttrot.com/v1" + if not openai_base_url: + openai_base_url = None + print("Using Offical OpenAI Base URL") + should_generate_images = ( params["isImageGenerationEnabled"] @@ -117,6 +129,7 @@ async def process_chunk(content): completion = await stream_openai_response( prompt_messages, api_key=openai_api_key, + base_url = openai_base_url, callback=lambda x: process_chunk(x), ) @@ -129,7 +142,7 @@ async def process_chunk(content): {"type": "status", "value": "Generating images..."} ) updated_html = await generate_images( - completion, api_key=openai_api_key, image_cache=image_cache + completion, api_key=openai_api_key, base_url=openai_base_url, image_cache=image_cache ) else: updated_html = completion diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c2ec3feb..7e89041e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -37,6 +37,7 @@ function App() { const [settings, setSettings] = usePersistedState( { openAiApiKey: null, + openAiBaseURL: null, screenshotOneApiKey: null, isImageGenerationEnabled: true, editorTheme: "cobalt", diff --git a/frontend/src/components/SettingsDialog.tsx b/frontend/src/components/SettingsDialog.tsx index f7004d3b..411dad0c 100644 --- a/frontend/src/components/SettingsDialog.tsx +++ b/frontend/src/components/SettingsDialog.tsx @@ -76,6 +76,25 @@ function SettingsDialog({ settings, setSettings }: Props) { } /> + + + + setSettings((s) => ({ + ...s, + openAiBaseURL: e.target.value, + })) + } + /> +