Skip to content

Commit e57a623

Browse files
authored
Merge pull request #9 from import-ai/feature/reader_timeout
Add timeout to html reader
2 parents (955e6cf + 330313e); commit e57a623

File tree

3 files changed

+53
-22
lines changed

3 files changed

+53
-22
lines changed

tests/function/test_html_reader.py

+12 -10
Original file line number | Diff line number | Diff line change
@@ -8,18 +8,20 @@
88
from common import project_root
99
from common.trace_info import TraceInfo
1010
from tests.helper.fixture import trace_info
11-
from wizard.config import OpenAIConfig
11+
from wizard.config import OpenAIConfig, ReaderConfig
1212
from wizard.entity import Task
1313
from wizard.wand.functions.html_reader import HTMLReader
1414

1515

1616
@pytest.fixture(scope="function")
17-
def openai_config() -> OpenAIConfig:
17+
def reader_config() -> ReaderConfig:
1818
load_dotenv(dotenv_path=project_root.path(".env"))
19-
return OpenAIConfig(
20-
api_key=os.environ["MBW_TASK_READER_API_KEY"],
21-
base_url=os.environ["MBW_TASK_READER_BASE_URL"],
22-
model=os.environ["MBW_TASK_READER_MODEL"],
19+
return ReaderConfig(
20+
openai=OpenAIConfig(
21+
api_key=os.environ["MBW_TASK_READER_OPENAI_API_KEY"],
22+
base_url=os.environ["MBW_TASK_READER_OPENAI_BASE_URL"],
23+
model=os.environ["MBW_TASK_READER_OPENAI_MODEL"],
24+
)
2325
)
2426

2527

@@ -29,15 +31,15 @@ def task() -> Task:
2931
return pickle.load(f)
3032

3133

32-
async def test_html_reader(openai_config: OpenAIConfig, task: Task, trace_info: TraceInfo):
33-
c = HTMLReader(openai_config)
34+
async def test_html_reader(reader_config: ReaderConfig, task: Task, trace_info: TraceInfo):
35+
c = HTMLReader(reader_config)
3436
result = await c.run(task, trace_info)
3537
print(jsonlib.dumps(result, ensure_ascii=False, separators=(",", ":")))
3638
# assert "Implement a notification system for updates and alerts." in result["markdown"]
3739

3840

39-
async def test_html_clean(openai_config: OpenAIConfig, task: Task):
40-
c = HTMLReader(openai_config)
41+
async def test_html_clean(reader_config: ReaderConfig, task: Task):
42+
c = HTMLReader(reader_config)
4143
html = task.input["html"]
4244
url = task.input["url"]
4345
print(f"raw length: {len(html)}")

wizard/config.py

+6 -1
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,13 @@ class BackendConfig(BaseModel):
3535
base_url: str
3636

3737

38+
class ReaderConfig(BaseModel):
39+
openai: OpenAIConfig
40+
timeout: float = Field(default=180, description="timeout second for reading html")
41+
42+
3843
class TaskConfig(BaseModel):
39-
reader: OpenAIConfig
44+
reader: ReaderConfig
4045

4146

4247
class Config(BaseModel):

wizard/wand/functions/html_reader.py

+35 -11
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
import asyncio
22
import json as jsonlib
33
import re
4+
from urllib.parse import urlparse
45

56
from bs4 import BeautifulSoup, Comment, Tag
67
from openai import AsyncOpenAI
78

89
from common.trace_info import TraceInfo
9-
from wizard.config import OpenAIConfig
10+
from wizard.config import OpenAIConfig, ReaderConfig
1011
from wizard.entity import Task
1112
from wizard.wand.functions.base_function import BaseFunction
1213

@@ -46,13 +47,17 @@ class HTMLReader(BaseFunction):
4647
"news.qq.com": {
4748
"name": "div",
4849
"class_": "content-article"
50+
},
51+
"zhuanlan.zhihu.com": {
52+
"name": "article"
4953
}
50-
5154
}
5255

53-
def __init__(self, openai_config: OpenAIConfig):
56+
def __init__(self, reader_config: ReaderConfig):
57+
openai_config: OpenAIConfig = reader_config.openai
5458
self.client = AsyncOpenAI(api_key=openai_config.api_key, base_url=openai_config.base_url)
55-
self.model = openai_config.model
59+
self.model: str = openai_config.model
60+
self.timeout: float = reader_config.timeout
5661

5762
@classmethod
5863
def content_selector(cls, url: str, soup: BeautifulSoup) -> Tag:
@@ -127,7 +132,9 @@ def create_prompt(cls, text: str, instruction: str = None, schema: str = None) -
127132
if not instruction:
128133
instruction = "Extract the main content from the given HTML and convert it to Markdown format."
129134
if schema:
130-
instruction = "Extract the specified information from the given HTML and present it in a structured JSON format. If any of the fields are not found in the HTML document, set their values to `Unknown` in the JSON output."
135+
instruction = ("Extract the specified information from the given HTML and present it in a structured JSON "
136+
"format. If any of the fields are not found in the HTML document, set their values to "
137+
"`Unknown` in the JSON output.")
131138
prompt = f"{instruction}\n```html\n{text}\n```\nThe JSON schema is as follows:```json\n{schema}\n```"
132139
else:
133140
prompt = f"{instruction}\n```html\n{text}\n```"
@@ -176,14 +183,31 @@ async def run(self, task: Task, trace_info: TraceInfo, stream: bool = False) ->
176183
html = input_dict["html"]
177184
url = input_dict["url"]
178185

186+
domain: str = urlparse(url).netloc
187+
trace_info = trace_info.bind(domain=domain)
188+
179189
cleaned_html = self.clean_html(url, html, clean_svg=True, clean_base64=True, remove_atts=True,
180190
compress=True, remove_empty_tag=True, enable_content_selector=True)
181-
trace_info.info({"len(html)": len(html), "len(cleaned_html)": len(cleaned_html)})
182-
183-
metadata, content = await asyncio.gather(
184-
self.extract_content(cleaned_html, schema=self.SCHEMA),
185-
self.extract_content(cleaned_html, stream=stream)
186-
)
191+
trace_info.info({
192+
"len(html)": len(html),
193+
"len(cleaned_html)": len(cleaned_html),
194+
"compress_rate": f"{len(cleaned_html) * 100 / len(html): .2f}%"
195+
})
196+
197+
metadata_task = asyncio.create_task(self.extract_content(cleaned_html, schema=self.SCHEMA))
198+
content_task = asyncio.create_task(self.extract_content(cleaned_html, stream=stream))
199+
200+
try:
201+
metadata = await asyncio.wait_for(metadata_task, timeout=self.timeout)
202+
except asyncio.TimeoutError:
203+
trace_info.error({"error": "metadata TimeoutError"})
204+
metadata = {}
205+
206+
try:
207+
content = await asyncio.wait_for(content_task, timeout=self.timeout)
208+
except asyncio.TimeoutError:
209+
trace_info.error({"error": "content TimeoutError"})
210+
content = "Timeout, please retry."
187211

188212
filtered_metadata: dict = {k: v for k, v in metadata.items() if v != "Unknown"}
189213

0 commit comments

Comments
 (0)