Add Astro forum docs (#212)
closes: #120
pankajastro authored Dec 20, 2023
1 parent 14ad2bf commit 441def6
Showing 3 changed files with 253 additions and 1 deletion.
51 changes: 51 additions & 0 deletions airflow/dags/ingestion/ask-astro-forum-load.py
@@ -0,0 +1,51 @@
import datetime
import os

from include.tasks import split
from include.tasks.extract.astro_forum_docs import get_forum_df
from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook

from airflow.decorators import dag, task

ask_astro_env = os.environ.get("ASK_ASTRO_ENV", "dev")

_WEAVIATE_CONN_ID = f"weaviate_{ask_astro_env}"
WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "DocsDev")
ask_astro_weaviate_hook = AskAstroWeaviateHook(_WEAVIATE_CONN_ID)

blog_cutoff_date = datetime.date(2022, 1, 1)

default_args = {"retries": 3, "retry_delay": 30}

schedule_interval = "0 5 * * *" if ask_astro_env == "prod" else None


@task
def get_astro_forum_content():
return get_forum_df()


@dag(
schedule_interval=schedule_interval,
start_date=datetime.datetime(2023, 9, 27),
catchup=False,
is_paused_upon_creation=True,
default_args=default_args,
)
def ask_astro_load_astro_forum():
split_docs = task(split.split_html).expand(dfs=[get_astro_forum_content()])

_import_data = (
task(ask_astro_weaviate_hook.ingest_data, retries=10)
.partial(
class_name=WEAVIATE_CLASS,
existing="upsert",
doc_key="docLink",
batch_params={"batch_size": 1000},
verbose=True,
)
.expand(dfs=[split_docs])
)


ask_astro_load_astro_forum()
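
For readers unfamiliar with the .partial()/.expand() calls above, the short sketch below (not part of this commit) illustrates how Airflow's dynamic task mapping creates one mapped task instance per element of the expanded argument. The toy_ingest callable and toy_mapping_example DAG are hypothetical stand-ins for AskAstroWeaviateHook.ingest_data and the real DAG.

import datetime

from airflow.decorators import dag, task


def toy_ingest(docs: list[str], class_name: str, verbose: bool = False) -> int:
    # Hypothetical stand-in for AskAstroWeaviateHook.ingest_data: "ingest" one mapped batch.
    if verbose:
        print(f"Ingesting {len(docs)} documents into {class_name}")
    return len(docs)


@dag(schedule_interval=None, start_date=datetime.datetime(2023, 9, 27), catchup=False)
def toy_mapping_example():
    @task
    def make_batches() -> list[list[str]]:
        # Two batches -> two mapped instances of the ingest task at run time.
        return [["doc-a", "doc-b"], ["doc-c"]]

    # partial() pins the keyword arguments shared by every mapped instance;
    # expand() maps the task over the batches returned by make_batches().
    task(toy_ingest).partial(class_name="DocsDev", verbose=True).expand(docs=make_batches())


toy_mapping_example()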
22 changes: 21 additions & 1 deletion airflow/dags/ingestion/ask-astro-load.py
@@ -9,6 +9,7 @@
import pandas as pd
from include.tasks import split
from include.tasks.extract import airflow_docs, astro_cli_docs, blogs, github, registry, stack_overflow
from include.tasks.extract.astro_forum_docs import get_forum_df
from include.tasks.extract.astro_sdk_docs import extract_astro_sdk_docs
from include.tasks.extract.astronomer_providers_docs import extract_provider_docs
from include.tasks.extract.utils.weaviate.ask_astro_weaviate_hook import AskAstroWeaviateHook
@@ -150,6 +151,7 @@ def check_seed_baseline(seed_baseline_url: str = None) -> str | set:
"extract_astro_cli_docs",
"extract_astro_sdk_doc",
"extract_astro_provider_doc",
"extract_astro_forum_doc",
}

@task(trigger_rule="none_failed")
@@ -240,6 +242,17 @@ def extract_stack_overflow(tag: str, stackoverflow_cutoff_date: str = stackoverf

return df

@task(trigger_rule="none_failed")
def extract_astro_forum_doc():
astro_forum_parquet_path = "include/data/astronomer/docs/astro-forum.parquet"
try:
df = pd.read_parquet(astro_forum_parquet_path)
except Exception:
df = get_forum_df()[0]
df.to_parquet(astro_forum_parquet_path)

return [df]

@task(trigger_rule="none_failed")
def extract_github_issues(repo_base: str):
parquet_file = f"include/data/{repo_base}/issues.parquet"
@@ -311,6 +324,7 @@ def extract_astro_blogs():
_astro_cli_docs = extract_astro_cli_docs()
_extract_astro_sdk_docs = extract_astro_sdk_doc()
_extract_astro_providers_docs = extract_astro_provider_doc()
_astro_forum_docs = extract_astro_forum_doc()

_get_schema = get_schema_and_process(schema_file="include/data/schema.json")
_check_schema = check_schema(class_objects=_get_schema)
@@ -325,7 +339,13 @@ def extract_astro_blogs():
registry_cells_docs,
]

html_tasks = [_airflow_docs, _astro_cli_docs, _extract_astro_sdk_docs, _extract_astro_providers_docs]
html_tasks = [
_airflow_docs,
_astro_cli_docs,
_extract_astro_sdk_docs,
_extract_astro_providers_docs,
_astro_forum_docs,
]

python_code_tasks = [registry_dags_docs, code_samples]

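The new extract_astro_forum_doc task above follows the same read-or-refresh parquet caching pattern as the other extractors in this DAG: use the seeded parquet file when it is readable, otherwise scrape the forum and write the cache for the next run. A minimal standalone sketch of that pattern (the load_or_build helper is ours, not part of the project):

import pandas as pd


def load_or_build(parquet_path: str, builder) -> pd.DataFrame:
    # Return the cached DataFrame if the parquet file can be read;
    # otherwise build it from scratch and write the cache for next time.
    try:
        return pd.read_parquet(parquet_path)
    except Exception:
        df = builder()
        df.to_parquet(parquet_path)
        return df


# Illustrative usage mirroring extract_astro_forum_doc:
# df = load_or_build("include/data/astronomer/docs/astro-forum.parquet", lambda: get_forum_df()[0])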
181 changes: 181 additions & 0 deletions airflow/include/tasks/extract/astro_forum_docs.py
@@ -0,0 +1,181 @@
from __future__ import annotations

import logging
from datetime import datetime

import pandas as pd
import pytz
import requests
from bs4 import BeautifulSoup
from weaviate.util import generate_uuid5

cutoff_date = datetime(2022, 1, 1, tzinfo=pytz.UTC)

logger = logging.getLogger("airflow.task")


def get_questions_urls(html_content: str) -> list[str]:
"""
Extracts question URLs from HTML content using BeautifulSoup.
param html_content (str): The HTML content of a web page.
"""
soup = BeautifulSoup(html_content, "html.parser")
return [a_tag.attrs.get("href") for a_tag in soup.findAll("a", class_="title raw-link raw-topic-link")]


def get_publish_date(html_content) -> datetime:
"""
    Extracts and parses the publish date from HTML content.
    param html_content (str): The HTML content of a web page.
"""
soup = BeautifulSoup(html_content, "html.parser")
# TODO: use og:article:tag for tag filter
publish_date = soup.find("meta", property="article:published_time")["content"]
publish_date = datetime.fromisoformat(publish_date)
return publish_date


def filter_cutoff_questions(questions_urls: list[str]) -> list[str]:
"""
Filters a list of question URLs based on the publish dates.
param questions_urls (list[str]): A list of question URLs.
"""
filter_questions_urls = []

for question_url in questions_urls:
try:
html_content = requests.get(question_url).content
except requests.RequestException as e:
logger.error(f"Error fetching content for {question_url}: {e}")
continue # Move on to the next iteration

soup = BeautifulSoup(html_content, "html.parser")
reply = soup.find("div", itemprop="comment")
if not reply:
logger.info(f"No response, Ignoring {question_url}")
continue

if get_publish_date(html_content) >= cutoff_date:
filter_questions_urls.append(question_url)

return filter_questions_urls


def get_cutoff_questions(forum_url: str) -> set[str]:
"""
Retrieves a set of valid question URLs from a forum page.
param forum_url (str): The URL of the forum.
"""
page_number = 0
base_url = f"{forum_url}?page="
all_valid_url = []
while True:
page_url = f"{base_url}{page_number}"
logger.info(page_url)
page_number = page_number + 1
html_content = requests.get(page_url).content
questions_urls = get_questions_urls(html_content)
        if not questions_urls:  # no questions on this page; we have reached the end
return set(all_valid_url)
filter_questions_urls = filter_cutoff_questions(questions_urls)
all_valid_url.extend(filter_questions_urls)


def truncate_tokens(text: str, encoding_name: str, max_length: int = 8192) -> str:
"""
Truncates a text string based on the maximum number of tokens.
    param text (str): The input text string to be truncated.
    param encoding_name (str): The name of the model used to look up the encoding.
param max_length (int): The maximum number of tokens allowed. Default is 8192.
"""
import tiktoken

try:
encoding = tiktoken.encoding_for_model(encoding_name)
    except KeyError as e:  # tiktoken raises KeyError for unknown model names
        raise ValueError(f"Invalid encoding_name: {e}") from e

encoded_string = encoding.encode(text)
num_tokens = len(encoded_string)

if num_tokens > max_length:
text = encoding.decode(encoded_string[:max_length])

return text


def clean_content(row_content: str) -> str | None:
"""
Cleans and extracts text content from HTML.
param row_content (str): The HTML content to be cleaned.
"""
soup = BeautifulSoup(row_content, "html.parser").find("body")

if soup is None:
return
# Remove script and style tags
for script_or_style in soup(["script", "style"]):
script_or_style.extract()

# Get text and handle whitespaces
text = " ".join(soup.stripped_strings)
    # Truncate because in some cases the content exceeds the maximum token size.
    # A better solution would be to summarize the content and ingest the summary.
return truncate_tokens(text, "gpt-3.5-turbo", 7692)


def fetch_url_content(url) -> str | None:
"""
Fetches the content of a URL.
param url (str): The URL to fetch content from.
"""
try:
response = requests.get(url)
response.raise_for_status() # Raise an HTTPError for bad responses
return response.content
    except requests.RequestException as e:
        logger.error("Error fetching content for %s: %s", url, e)
return None


def process_url(url: str, doc_source: str = "") -> dict | None:
"""
Process a URL by fetching its content, cleaning it, and generating a unique identifier (SHA) based on the cleaned content.
    param url (str): The URL to be processed.
    param doc_source (str): Label stored in the docSource column of the result.
"""
content = fetch_url_content(url)
if content is not None:
cleaned_content = clean_content(content)
sha = generate_uuid5(cleaned_content)
return {"docSource": doc_source, "sha": sha, "content": cleaned_content, "docLink": url}


def url_to_df(urls: set[str], doc_source: str = "") -> pd.DataFrame:
"""
    Create a DataFrame from a set of URLs by processing each URL and organizing the results.
    param urls (set[str]): The URLs to be processed.
    param doc_source (str): Label stored in the docSource column of the result.
"""
df_data = [process_url(url, doc_source) for url in urls]
df_data = [entry for entry in df_data if entry is not None] # Remove failed entries
df = pd.DataFrame(df_data)
df = df[["docSource", "sha", "content", "docLink"]] # Reorder columns if needed
return df


def get_forum_df() -> list[pd.DataFrame]:
"""
Retrieves question links from a forum, converts them into a DataFrame, and returns a list containing the DataFrame.
"""
questions_links = get_cutoff_questions("https://forum.astronomer.io/latest")
logger.info(questions_links)
df = url_to_df(questions_links, "astro-forum")
return [df]
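
A quick illustrative check of truncate_tokens as defined above (assumes the tiktoken package is installed; not part of this commit):

from include.tasks.extract.astro_forum_docs import truncate_tokens

text = "airflow scheduler dag " * 2000  # far more than 100 tokens
clipped = truncate_tokens(text, "gpt-3.5-turbo", max_length=100)
print(len(clipped) < len(text))  # True: the text was decoded back from only its first 100 tokens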
