Skip to content

Commit 6b8f7db

Browse files
authored
feat:替代并实现gradio logout route并添加退出按钮 (#1034)
* feat:替代并实现gradio logout route并添加退出按钮 * 在无验证情况下隐藏
1 parent 40a0cc7 commit 6b8f7db

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed

ChuanhuChatbot.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from modules import config
1717
import gradio as gr
1818
import colorama
19+
from modules.gradio_patch import reg_patch
1920

21+
reg_patch()
2022

2123
logging.getLogger("httpx").setLevel(logging.WARNING)
2224

@@ -33,6 +35,8 @@ def create_new_model():
3335

3436
with gr.Blocks(theme=small_and_beautiful_theme) as demo:
3537
user_name = gr.Textbox("", visible=False)
38+
# 激活/logout路由
39+
logout_hidden_btn = gr.LogoutButton(visible=False)
3640
promptTemplates = gr.State(load_template(get_template_names()[0], mode=2))
3741
user_question = gr.State("")
3842
assert type(my_api_key) == str
@@ -391,6 +395,8 @@ def create_new_model():
391395
single_turn_checkbox = gr.Checkbox(label=i18n(
392396
"单轮对话"), value=False, elem_classes="switch-checkbox", elem_id="gr-single-session-cb", visible=False)
393397
# checkUpdateBtn = gr.Button(i18n("🔄 检查更新..."), visible=check_update)
398+
logout_btn = gr.Button(
399+
i18n("退出用户"), variant="primary", visible=authflag)
394400

395401
with gr.Tab(i18n("网络")):
396402
gr.Markdown(
@@ -801,7 +807,12 @@ def create_greeting(request: gr.Request):
801807
outputs=[saveFileName, systemPromptTxt, chatbot, single_turn_checkbox, temperature_slider, top_p_slider, n_choices_slider, stop_sequence_txt, max_context_length_slider, max_generation_slider, presence_penalty_slider, frequency_penalty_slider, logit_bias_txt, user_identifier_txt],
802808
_js='(a,b)=>{return bgSelectHistory(a,b);}'
803809
)
804-
810+
logout_btn.click(
811+
fn=None,
812+
inputs=[],
813+
outputs=[],
814+
_js='self.location="/logout"'
815+
)
805816
# 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
806817
demo.title = i18n("川虎Chat 🚀")
807818

modules/gradio_patch.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import logging
2+
import os
3+
4+
import fastapi
5+
import gradio
6+
from fastapi.responses import RedirectResponse
7+
from gradio.oauth import MOCKED_OAUTH_TOKEN
8+
9+
from modules.presets import i18n
10+
11+
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
12+
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
13+
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
14+
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
15+
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
16+
"""Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
17+
try:
18+
from authlib.integrations.starlette_client import OAuth
19+
except ImportError as e:
20+
raise ImportError(
21+
"Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
22+
"`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
23+
) from e
24+
25+
# Check environment variables
26+
msg = (
27+
"OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
28+
" setting `hf_oauth: true` in the Space metadata."
29+
)
30+
if OAUTH_CLIENT_ID is None:
31+
raise ValueError(msg.format("OAUTH_CLIENT_ID"))
32+
if OAUTH_CLIENT_SECRET is None:
33+
raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
34+
if OAUTH_SCOPES is None:
35+
raise ValueError(msg.format("OAUTH_SCOPES"))
36+
if OPENID_PROVIDER_URL is None:
37+
raise ValueError(msg.format("OPENID_PROVIDER_URL"))
38+
39+
# Register OAuth server
40+
oauth = OAuth()
41+
oauth.register(
42+
name="huggingface",
43+
client_id=OAUTH_CLIENT_ID,
44+
client_secret=OAUTH_CLIENT_SECRET,
45+
client_kwargs={"scope": OAUTH_SCOPES},
46+
server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
47+
)
48+
49+
# Define OAuth routes
50+
@app.get("/login/huggingface")
51+
async def oauth_login(request: fastapi.Request):
52+
"""Endpoint that redirects to HF OAuth page."""
53+
redirect_uri = str(request.url_for("oauth_redirect_callback"))
54+
if ".hf.space" in redirect_uri:
55+
# In Space, FastAPI redirect as http but we want https
56+
redirect_uri = redirect_uri.replace("http://", "https://")
57+
return await oauth.huggingface.authorize_redirect(request, redirect_uri)
58+
59+
@app.get("/login/callback")
60+
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
61+
"""Endpoint that handles the OAuth callback."""
62+
token = await oauth.huggingface.authorize_access_token(request)
63+
request.session["oauth_profile"] = token["userinfo"]
64+
request.session["oauth_token"] = token
65+
return RedirectResponse("/")
66+
67+
@app.get("/logout")
68+
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
69+
"""Endpoint that logs out the user (e.g. delete cookie session)."""
70+
request.session.pop("oauth_profile", None)
71+
request.session.pop("oauth_token", None)
72+
# 清除cookie并跳转到首页
73+
response = RedirectResponse(url="/", status_code=302)
74+
response.delete_cookie(key=f"access-token")
75+
response.delete_cookie(key=f"access-token-unsecure")
76+
return response
77+
78+
79+
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
80+
"""Add fake oauth routes if Gradio is run locally and OAuth is enabled.
81+
82+
Clicking on a gr.LoginButton will have the same behavior as in a Space (i.e. gets redirected in a new tab) but
83+
instead of authenticating with HF, a mocked user profile is added to the session.
84+
"""
85+
86+
# Define OAuth routes
87+
@app.get("/login/huggingface")
88+
async def oauth_login(request: fastapi.Request):
89+
"""Fake endpoint that redirects to HF OAuth page."""
90+
return RedirectResponse("/login/callback")
91+
92+
@app.get("/login/callback")
93+
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
94+
"""Endpoint that handles the OAuth callback."""
95+
request.session["oauth_profile"] = MOCKED_OAUTH_TOKEN["userinfo"]
96+
request.session["oauth_token"] = MOCKED_OAUTH_TOKEN
97+
return RedirectResponse("/")
98+
99+
@app.get("/logout")
100+
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
101+
"""Endpoint that logs out the user (e.g. delete cookie session)."""
102+
request.session.pop("oauth_profile", None)
103+
request.session.pop("oauth_token", None)
104+
# 清除cookie并跳转到首页
105+
response = RedirectResponse(url="/", status_code=302)
106+
response.delete_cookie(key=f"access-token")
107+
response.delete_cookie(key=f"access-token-unsecure")
108+
return response
109+
110+
111+
def reg_patch():
112+
gradio.oauth._add_mocked_oauth_routes = _add_mocked_oauth_routes
113+
gradio.oauth._add_oauth_routes = _add_oauth_routes
114+
logging.info(i18n("覆盖gradio.oauth /logout路由"))

0 commit comments

Comments
 (0)