Showing 3 changed files with 253 additions and 1 deletion.
@@ -0,0 +1,51 @@

import datetime
import os

from include.tasks import split
from include.tasks.extract.astro_forum_docs import get_forum_df
from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook

from airflow.decorators import dag, task

ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev")

_WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}"
WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev")
ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID)

blog_cutoff_date = datetime.date(2022, 1, 1)

default_args = {"retries": 3, "retry_delay": 30}

schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None


@task
def get_astro_forum_content():
    return get_forum_df()


@dag(
    schedule_interval=schedule_interval,
    start_date=datetime.datetime(2023, 9, 27),
    catchup=False,
    is_paused_upon_creation=True,
    default_args=default_args,
)
def ask_astro_load_astro_forum():
    split_docs = task(split.split_html).expand(dfs=[get_astro_forum_content()])

    _import_data = (
        task(ask_astro_weaviate_hook.ingest_data, retries=10)
        .partial(
            class_name=WEAVIATE_CLASS,
            existing="upsert",
            doc_key="docLink",
            batch_params={"batch_size": 1000},
            verbose=True,
        )
        .expand(dfs=[split_docs])
    )


ask_astro_load_astro_forum()
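
A minimal way to smoke-test this DAG outside the scheduler (a sketch, not part of the commit; it assumes Airflow 2.5+ for dag.test(), that the DAG module is on the import path, and that the weaviate_dev connection and API credentials are configured):

    if __name__ == "__main__":
        # dag.test() runs one DAG run in-process, which is handy here because
        # the dev schedule is None and the DAG starts paused.
        ask_astro_load_astro_forum().test()
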
@@ -0,0 +1,181 @@

from __future__ import annotations

import logging
from datetime import datetime

import pandas as pd
import pytz
import requests
from bs4 import BeautifulSoup
from weaviate.util import generate_uuid5

cutoff_date = datetime(2022, 1, 1, tzinfo=pytz.UTC)

logger = logging.getLogger("airflow.task")

def get_questions_urls(html_content: str) -> list[str]:
    """
    Extracts question URLs from HTML content using BeautifulSoup.

    :param html_content: The HTML content of a web page.
    """
    soup = BeautifulSoup(html_content, "html.parser")
    return [a_tag.attrs.get("href") for a_tag in soup.findAll("a", class_="title raw-link raw-topic-link")]

def get_publish_date(html_content: str) -> datetime:
    """
    Extracts and parses the publish date from HTML content.

    :param html_content: The HTML content of a web page.
    """
    soup = BeautifulSoup(html_content, "html.parser")
    # TODO: use og:article:tag for tag filter
    publish_date = soup.find("meta", property="article:published_time")["content"]
    return datetime.fromisoformat(publish_date)

def filter_cutoff_questions(questions_urls: list[str]) -> list[str]:
    """
    Filters a list of question URLs based on their publish dates.

    :param questions_urls: A list of question URLs.
    """
    filtered_questions_urls = []

    for question_url in questions_urls:
        try:
            html_content = requests.get(question_url).content
        except requests.RequestException as e:
            logger.error(f"Error fetching content for {question_url}: {e}")
            continue  # move on to the next URL

        soup = BeautifulSoup(html_content, "html.parser")
        reply = soup.find("div", itemprop="comment")
        if not reply:
            logger.info(f"No replies; ignoring {question_url}")
            continue

        if get_publish_date(html_content) >= cutoff_date:
            filtered_questions_urls.append(question_url)

    return filtered_questions_urls

def get_cutoff_questions(forum_url: str) -> set[str]:
    """
    Retrieves a set of valid question URLs from a forum, page by page.

    :param forum_url: The URL of the forum.
    """
    page_number = 0
    base_url = f"{forum_url}?page="
    all_valid_urls = []
    while True:
        page_url = f"{base_url}{page_number}"
        logger.info(page_url)
        page_number += 1
        html_content = requests.get(page_url).content
        questions_urls = get_questions_urls(html_content)
        if not questions_urls:  # an empty page means we reached the end of the listing
            return set(all_valid_urls)
        all_valid_urls.extend(filter_cutoff_questions(questions_urls))

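# Illustrative call (comment only, not part of the original file): assuming
# network access to the forum, get_cutoff_questions("https://forum.astronomer.io/latest")
# returns the set of question URLs that have at least one reply and were
# published on or after cutoff_date.
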
def truncate_tokens(text: str, encoding_name: str, max_length: int = 8192) -> str:
    """
    Truncates a text string based on the maximum number of tokens.

    :param text: The input text string to be truncated.
    :param encoding_name: The name of the model whose encoding is used.
    :param max_length: The maximum number of tokens allowed. Default is 8192.
    """
    import tiktoken

    try:
        encoding = tiktoken.encoding_for_model(encoding_name)
    except (KeyError, ValueError) as e:  # tiktoken raises KeyError for unknown model names
        raise ValueError(f"Invalid encoding_name: {e}")

    encoded_string = encoding.encode(text)
    num_tokens = len(encoded_string)

    if num_tokens > max_length:
        text = encoding.decode(encoded_string[:max_length])

    return text

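# Illustrative behaviour (comment only, not part of the original file):
#     truncate_tokens("word " * 10_000, "gpt-3.5-turbo", max_length=100)
# decodes only the first 100 tokens back to text; inputs at or under the
# limit are returned unchanged.
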
def clean_content(row_content: str) -> str | None:
    """
    Cleans and extracts text content from HTML.

    :param row_content: The HTML content to be cleaned.
    """
    soup = BeautifulSoup(row_content, "html.parser").find("body")

    if soup is None:
        return None

    # Remove script and style tags
    for script_or_style in soup(["script", "style"]):
        script_or_style.extract()

    # Get text and normalize whitespace
    text = " ".join(soup.stripped_strings)
    # Truncate because some pages exceed the model's max token size;
    # a better solution would be to summarize the content and ingest the summary.
    return truncate_tokens(text, "gpt-3.5-turbo", 7692)

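# Illustrative behaviour (comment only, not part of the original file):
#     clean_content("<html><body><p>Hello</p><script>x()</script></body></html>")
# returns "Hello"; markup without a <body> tag returns None.
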
def fetch_url_content(url: str) -> str | None:
    """
    Fetches the content of a URL.

    :param url: The URL to fetch content from.
    """
    try:
        response = requests.get(url)
        response.raise_for_status()  # raise an HTTPError for bad responses
        return response.content
    except requests.RequestException as e:
        logger.error("Error fetching content for %s: %s", url, e)
        return None

def process_url(url: str, doc_source: str = "") -> dict | None:
    """
    Processes a URL by fetching its content, cleaning it, and generating a unique identifier (SHA) based on the cleaned content.

    :param url: The URL to be processed.
    :param doc_source: Label stored in the docSource column.
    """
    content = fetch_url_content(url)
    if content is not None:
        cleaned_content = clean_content(content)
        sha = generate_uuid5(cleaned_content)
        return {"docSource": doc_source, "sha": sha, "content": cleaned_content, "docLink": url}

def url_to_df(urls: set[str], doc_source: str = "") -> pd.DataFrame:
    """
    Creates a DataFrame from a set of URLs by processing each URL and organizing the results.

    :param urls: The URLs to be processed.
    :param doc_source: Label stored in the docSource column.
    """
    df_data = [process_url(url, doc_source) for url in urls]
    df_data = [entry for entry in df_data if entry is not None]  # remove failed entries
    df = pd.DataFrame(df_data)
    df = df[["docSource", "sha", "content", "docLink"]]  # reorder columns if needed
    return df

def get_forum_df() -> list[pd.DataFrame]:
    """
    Retrieves question links from a forum, converts them into a DataFrame, and returns a list containing the DataFrame.
    """
    questions_links = get_cutoff_questions("https://forum.astronomer.io/latest")
    logger.info(questions_links)
    df = url_to_df(questions_links, "astro-forum")
    return [df]
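
For debugging the extraction outside Airflow, the same path can be exercised directly (a sketch, assuming network access to forum.astronomer.io and that pandas, tiktoken, and weaviate-client are installed; illustrative only, not part of the commit):

    links = get_cutoff_questions("https://forum.astronomer.io/latest")
    df = url_to_df(links, "astro-forum")
    print(df[["docLink", "sha"]].head())

This mirrors what get_forum_df() does inside the DAG, minus wrapping the DataFrame in a list for the dynamic task mapping.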