diff --git a/archon-ui-main/src/features/progress/components/CrawlingProgress.tsx b/archon-ui-main/src/features/progress/components/CrawlingProgress.tsx index ca03ecfbf5..8d27435577 100644 --- a/archon-ui-main/src/features/progress/components/CrawlingProgress.tsx +++ b/archon-ui-main/src/features/progress/components/CrawlingProgress.tsx @@ -12,6 +12,7 @@ import { Button } from "../../ui/primitives"; import { cn } from "../../ui/primitives/styles"; import { useCrawlProgressPolling } from "../hooks"; import type { ActiveOperation } from "../types/progress"; +import { isValidHttpUrl } from "../utils/urlValidation"; interface CrawlingProgressProps { onSwitchToBrowse: () => void; @@ -129,6 +130,7 @@ export const CrawlingProgress: React.FC = ({ onSwitchToBr "in_progress", "starting", "initializing", + "discovery", "analyzing", "storing", "source_creation", @@ -245,6 +247,63 @@ export const CrawlingProgress: React.FC = ({ onSwitchToBr )} + {/* Discovery Information */} + {operation.discovered_file && ( +
+
+            Discovery Result
+            {operation.discovered_file_type && (
+              {operation.discovered_file_type}
+            )}
+
+          {isValidHttpUrl(operation.discovered_file) ? (
+            {operation.discovered_file}
+          ) : (
+            {operation.discovered_file}
+          )}
+
+      )}
+
+      {/* Linked Files */}
+      {operation.linked_files && operation.linked_files.length > 0 && (
+
+
+          Following {operation.linked_files.length} Linked File
+          {operation.linked_files.length > 1 ? "s" : ""}
+
+
+          {operation.linked_files.map((file: string, idx: number) => (
+            isValidHttpUrl(file) ? (
+              • {file}
+            ) : (
+              • {file}
+            )
+          ))}
+
+
+      )}
+
       {/* Current Action or Operation Type Info */}
       {(operation.current_url || operation.operation_type) && (
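Reviewer note: the rendering above relies on the isValidHttpUrl guard (added in urlValidation.ts later in this patch) so that only http/https values become clickable links, while anything else (for example javascript: URIs) falls back to plain text. A minimal sketch of the behavior being relied on, assuming Vitest and a hypothetical test-file location for the import path — illustrative only, not part of the patch:

    import { describe, expect, it } from "vitest";
    // Import path is an assumption; adjust to wherever the test file lives.
    import { isValidHttpUrl } from "../utils/urlValidation";

    describe("isValidHttpUrl", () => {
      it("accepts plain http/https URLs", () => {
        expect(isValidHttpUrl("https://example.com/llms.txt")).toBe(true);
        expect(isValidHttpUrl("http://localhost:3000/sitemap.xml")).toBe(true);
      });

      it("rejects values that should render as plain text instead of links", () => {
        expect(isValidHttpUrl("javascript:alert(1)")).toBe(false); // disallowed protocol
        expect(isValidHttpUrl("ftp://example.com/llms.txt")).toBe(false); // non-HTTP scheme
        expect(isValidHttpUrl("not a url")).toBe(false); // fails URL parsing
        expect(isValidHttpUrl(undefined)).toBe(false); // missing value
      });
    });

Rendering the raw string (rather than dropping it) when validation fails keeps the progress view informative without turning unsafe values into anchors.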
diff --git a/archon-ui-main/src/features/progress/types/progress.ts b/archon-ui-main/src/features/progress/types/progress.ts index f129d1913d..c57426b9ca 100644 --- a/archon-ui-main/src/features/progress/types/progress.ts +++ b/archon-ui-main/src/features/progress/types/progress.ts @@ -6,6 +6,7 @@ export type ProgressStatus = | "starting" | "initializing" + | "discovery" | "analyzing" | "crawling" | "processing" @@ -24,7 +25,16 @@ export type ProgressStatus = | "cancelled" | "stopping"; -export type CrawlType = "normal" | "sitemap" | "llms-txt" | "text_file" | "refresh"; +export type CrawlType = + | "normal" + | "sitemap" + | "llms-txt" + | "text_file" + | "refresh" + | "llms_txt_with_linked_files" + | "llms_txt_linked_files" + | "discovery_single_file" + | "discovery_sitemap"; export type UploadType = "document"; export interface BaseProgressData { @@ -48,6 +58,10 @@ export interface CrawlProgressData extends BaseProgressData { codeBlocksFound?: number; totalSummaries?: number; completedSummaries?: number; + // Discovery-related fields + discoveredFile?: string; + discoveredFileType?: string; + linkedFiles?: string[]; originalCrawlParams?: { url: string; knowledge_type?: string; @@ -100,6 +114,10 @@ export interface ActiveOperation { code_examples_found?: number; current_operation?: string; }; + // Discovery information + discovered_file?: string; + discovered_file_type?: string; + linked_files?: string[]; } export interface ActiveOperationsResponse { @@ -127,6 +145,13 @@ export interface ProgressResponse { codeBlocksFound?: number; totalSummaries?: number; completedSummaries?: number; + // Discovery-related fields + discoveredFile?: string; + discovered_file?: string; // Snake case from backend + discoveredFileType?: string; + discovered_file_type?: string; // Snake case from backend + linkedFiles?: string[]; + linked_files?: string[]; // Snake case from backend fileName?: string; fileSize?: number; chunksProcessed?: number; diff --git a/archon-ui-main/src/features/progress/utils/urlValidation.ts b/archon-ui-main/src/features/progress/utils/urlValidation.ts new file mode 100644 index 0000000000..79f70bda7d --- /dev/null +++ b/archon-ui-main/src/features/progress/utils/urlValidation.ts @@ -0,0 +1,44 @@ +/** + * Client-side URL validation utility for discovered files. + * Ensures only safe HTTP/HTTPS URLs are rendered as clickable links. + */ + +const SAFE_PROTOCOLS = ["http:", "https:"]; + +/** + * Validates that a URL is safe to render as a clickable link. + * Only allows http: and https: protocols. 
+ * + * @param url - URL string to validate + * @returns true if URL is safe (http/https), false otherwise + */ +export function isValidHttpUrl(url: string | undefined | null): boolean { + if (!url || typeof url !== "string") { + return false; + } + + // Trim whitespace + const trimmed = url.trim(); + if (!trimmed) { + return false; + } + + try { + const parsed = new URL(trimmed); + + // Only allow http and https protocols + if (!SAFE_PROTOCOLS.includes(parsed.protocol)) { + return false; + } + + // Basic hostname validation (must have at least one dot or be localhost) + if (!parsed.hostname.includes(".") && parsed.hostname !== "localhost") { + return false; + } + + return true; + } catch { + // URL parsing failed - not a valid URL + return false; + } +} diff --git a/python/pyproject.toml b/python/pyproject.toml index 2c036d34eb..16b918ebd8 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -59,6 +59,7 @@ server = [ "pydantic>=2.0.0", "python-dotenv>=1.0.0", "docker>=6.1.0", + "tldextract>=5.0.0", # Logging "logfire>=0.30.0", # Testing (needed for UI-triggered tests) @@ -116,6 +117,7 @@ all = [ "cryptography>=41.0.0", "slowapi>=0.1.9", "docker>=6.1.0", + "tldextract>=5.0.0", "logfire>=0.30.0", # MCP specific (mcp version) "mcp==1.12.2", diff --git a/python/src/server/services/crawling/crawling_service.py b/python/src/server/services/crawling/crawling_service.py index 745f7d93db..01122704d8 100644 --- a/python/src/server/services/crawling/crawling_service.py +++ b/python/src/server/services/crawling/crawling_service.py @@ -11,6 +11,8 @@ from collections.abc import Awaitable, Callable from typing import Any, Optional +import tldextract + from ...config.logfire_config import get_logger, safe_logfire_error, safe_logfire_info from ...utils import get_supabase_client from ...utils.progress.progress_tracker import ProgressTracker @@ -18,12 +20,13 @@ # Import strategies # Import operations +from .discovery_service import DiscoveryService from .document_storage_operations import DocumentStorageOperations -from .page_storage_operations import PageStorageOperations from .helpers.site_config import SiteConfig # Import helpers from .helpers.url_handler import URLHandler +from .page_storage_operations import PageStorageOperations from .progress_mapper import ProgressMapper from .strategies.batch import BatchCrawlStrategy from .strategies.recursive import RecursiveCrawlStrategy @@ -37,6 +40,34 @@ _orchestration_lock: asyncio.Lock | None = None +def get_root_domain(host: str) -> str: + """ + Extract the root domain from a hostname using tldextract. + Handles multi-part public suffixes correctly (e.g., .co.uk, .com.au). 
+ + Args: + host: Hostname to extract root domain from + + Returns: + Root domain (domain + suffix) or original host if extraction fails + + Examples: + - "docs.example.com" -> "example.com" + - "api.example.co.uk" -> "example.co.uk" + - "localhost" -> "localhost" + """ + try: + extracted = tldextract.extract(host) + # Return domain.suffix if both are present + if extracted.domain and extracted.suffix: + return f"{extracted.domain}.{extracted.suffix}" + # Fallback to original host if extraction yields no domain or suffix + return host + except Exception: + # If extraction fails, return original host + return host + + def _ensure_orchestration_lock() -> asyncio.Lock: global _orchestration_lock if _orchestration_lock is None: @@ -99,6 +130,7 @@ def __init__(self, crawler=None, supabase_client=None, progress_id=None): # Initialize operations self.doc_storage_ops = DocumentStorageOperations(self.supabase_client) + self.discovery_service = DiscoveryService() self.page_storage_ops = PageStorageOperations(self.supabase_client) # Track progress state across all stages to prevent UI resets @@ -196,13 +228,16 @@ async def crawl_single_page(self, url: str, retry_count: int = 3) -> dict[str, A ) async def crawl_markdown_file( - self, url: str, progress_callback: Callable[[str, int, str], Awaitable[None]] | None = None + self, url: str, progress_callback: Callable[[str, int, str], Awaitable[None]] | None = None, + start_progress: int = 10, end_progress: int = 20 ) -> list[dict[str, Any]]: """Crawl a .txt or markdown file.""" return await self.single_page_strategy.crawl_markdown_file( url, self.url_handler.transform_github_url, progress_callback, + start_progress, + end_progress, ) def parse_sitemap(self, sitemap_url: str) -> list[str]: @@ -351,15 +386,102 @@ async def update_mapped_progress( # Check for cancellation before proceeding self._check_cancellation() - # Analyzing stage - report initial page count (at least 1) - await update_mapped_progress( - "analyzing", 50, f"Analyzing URL type for {url}", - total_pages=1, # We know we have at least the start URL - processed_pages=0 + # Discovery phase - find the single best related file + discovered_urls = [] + # Skip discovery if the URL itself is already a discovery target (sitemap, llms file, etc.) 
+ is_already_discovery_target = ( + self.url_handler.is_sitemap(url) or + self.url_handler.is_llms_variant(url) or + self.url_handler.is_robots_txt(url) or + self.url_handler.is_well_known_file(url) or + self.url_handler.is_txt(url) # Also skip for any .txt file that user provides directly ) - # Detect URL type and perform crawl - crawl_results, crawl_type = await self._crawl_by_url_type(url, request) + if is_already_discovery_target: + safe_logfire_info(f"Skipping discovery - URL is already a discovery target file: {url}") + + if request.get("auto_discovery", True) and not is_already_discovery_target: # Default enabled, but skip if already a discovery file + await update_mapped_progress( + "discovery", 25, f"Discovering best related file for {url}", current_url=url + ) + try: + # Offload potential sync I/O to avoid blocking the event loop + discovered_file = await asyncio.to_thread(self.discovery_service.discover_files, url) + + # Add the single best discovered file to crawl list + if discovered_file: + safe_logfire_info(f"Discovery found file: {discovered_file}") + # Filter through is_binary_file() check like existing code + if not self.url_handler.is_binary_file(discovered_file): + discovered_urls.append(discovered_file) + safe_logfire_info(f"Adding discovered file to crawl: {discovered_file}") + + # Determine file type for user feedback + discovered_file_type = "unknown" + if self.url_handler.is_llms_variant(discovered_file): + discovered_file_type = "llms.txt" + elif self.url_handler.is_sitemap(discovered_file): + discovered_file_type = "sitemap" + elif self.url_handler.is_robots_txt(discovered_file): + discovered_file_type = "robots.txt" + + await update_mapped_progress( + "discovery", 100, + f"Discovery completed: found {discovered_file_type} file", + current_url=url, + discovered_file=discovered_file, + discovered_file_type=discovered_file_type + ) + else: + safe_logfire_info(f"Skipping binary file: {discovered_file}") + else: + safe_logfire_info(f"Discovery found no files for {url}") + await update_mapped_progress( + "discovery", 100, + "Discovery completed: no special files found, will crawl main URL", + current_url=url + ) + + except Exception as e: + safe_logfire_error(f"Discovery phase failed: {e}") + # Continue with regular crawl even if discovery fails + await update_mapped_progress( + "discovery", 100, "Discovery phase failed, continuing with regular crawl", current_url=url + ) + + # Analyzing stage - determine what to crawl + if discovered_urls: + # Discovery found a file - crawl ONLY the discovered file, not the main URL + total_urls_to_crawl = len(discovered_urls) + await update_mapped_progress( + "analyzing", 50, f"Analyzing discovered file: {discovered_urls[0]}", + total_pages=total_urls_to_crawl, + processed_pages=0 + ) + + # Crawl only the discovered file with discovery context + discovered_url = discovered_urls[0] + safe_logfire_info(f"Crawling discovered file instead of main URL: {discovered_url}") + + # Mark this as a discovery target for domain filtering + discovery_request = request.copy() + discovery_request["is_discovery_target"] = True + discovery_request["original_domain"] = self.url_handler.get_base_url(discovered_url) + + crawl_results, crawl_type = await self._crawl_by_url_type(discovered_url, discovery_request) + + else: + # No discovery - crawl the main URL normally + total_urls_to_crawl = 1 + await update_mapped_progress( + "analyzing", 50, f"Analyzing URL type for {url}", + total_pages=total_urls_to_crawl, + processed_pages=0 + ) + + # Crawl the main 
URL + safe_logfire_info(f"No discovery file found, crawling main URL: {url}") + crawl_results, crawl_type = await self._crawl_by_url_type(url, request) # Update progress tracker with crawl type if self.progress_tracker and crawl_type: @@ -531,7 +653,7 @@ async def code_progress_callback(data: dict): logger.error("Code extraction failed, continuing crawl without code examples", exc_info=True) safe_logfire_error(f"Code extraction failed | error={e}") code_examples_count = 0 - + # Report code extraction failure to progress tracker if self.progress_tracker: await self.progress_tracker.update( @@ -628,6 +750,66 @@ async def code_progress_callback(data: dict): f"Unregistered orchestration service on error | progress_id={self.progress_id}" ) + def _is_same_domain(self, url: str, base_domain: str) -> bool: + """ + Check if a URL belongs to the same domain as the base domain. + + Args: + url: URL to check + base_domain: Base domain URL to compare against + + Returns: + True if the URL is from the same domain + """ + try: + from urllib.parse import urlparse + u, b = urlparse(url), urlparse(base_domain) + url_host = (u.hostname or "").lower() + base_host = (b.hostname or "").lower() + return bool(url_host) and url_host == base_host + except Exception: + # If parsing fails, be conservative and exclude the URL + return False + + def _is_same_domain_or_subdomain(self, url: str, base_domain: str) -> bool: + """ + Check if a URL belongs to the same root domain or subdomain. + + Examples: + - docs.supabase.com matches supabase.com (subdomain) + - api.supabase.com matches supabase.com (subdomain) + - supabase.com matches supabase.com (exact match) + - external.com does NOT match supabase.com + + Args: + url: URL to check + base_domain: Base domain URL to compare against + + Returns: + True if the URL is from the same root domain or subdomain + """ + try: + from urllib.parse import urlparse + u, b = urlparse(url), urlparse(base_domain) + url_host = (u.hostname or "").lower() + base_host = (b.hostname or "").lower() + + if not url_host or not base_host: + return False + + # Exact match + if url_host == base_host: + return True + + # Check if url_host is a subdomain of base_host using tldextract + url_root = get_root_domain(url_host) + base_root = get_root_domain(base_host) + + return url_root == base_root + except Exception: + # If parsing fails, be conservative and exclude the URL + return False + def _is_self_link(self, link: str, base_url: str) -> bool: """ Check if a link is a self-referential link to the base URL. 
@@ -700,6 +882,63 @@ async def update_crawl_progress(stage_progress: int, message: str, **kwargs): if crawl_results and len(crawl_results) > 0: content = crawl_results[0].get('markdown', '') if self.url_handler.is_link_collection_file(url, content): + # If this file was selected by discovery, check if it's an llms.txt file + if request.get("is_discovery_target"): + # Check if this is an llms.txt file (not sitemap or other discovery targets) + is_llms_file = self.url_handler.is_llms_variant(url) + + if is_llms_file: + logger.info(f"Discovery llms.txt mode: following ALL same-domain links from {url}") + + # Extract all links from the file + extracted_links_with_text = self.url_handler.extract_markdown_links_with_text(content, url) + + # Filter for same-domain links (all types, not just llms.txt) + same_domain_links = [] + if extracted_links_with_text: + original_domain = request.get("original_domain") + if original_domain: + for link, text in extracted_links_with_text: + # Check same domain/subdomain for ALL links + if self._is_same_domain_or_subdomain(link, original_domain): + same_domain_links.append((link, text)) + logger.debug(f"Found same-domain link: {link}") + + if same_domain_links: + # Build mapping and extract just URLs + url_to_link_text = dict(same_domain_links) + extracted_urls = [link for link, _ in same_domain_links] + + logger.info(f"Following {len(extracted_urls)} same-domain links from llms.txt") + + # Notify user about linked files being crawled + await update_crawl_progress( + 60, # 60% of crawling stage + f"Found {len(extracted_urls)} links in llms.txt, crawling them now...", + crawl_type="llms_txt_linked_files", + linked_files=extracted_urls + ) + + # Crawl all same-domain links from llms.txt (no recursion, just one level) + batch_results = await self.crawl_batch_with_progress( + extracted_urls, + max_concurrent=request.get('max_concurrent'), + progress_callback=await self._create_crawl_progress_callback("crawling"), + link_text_fallbacks=url_to_link_text, + ) + + # Combine original llms.txt with linked pages + crawl_results.extend(batch_results) + crawl_type = "llms_txt_with_linked_pages" + logger.info(f"llms.txt crawling completed: {len(crawl_results)} total pages (1 llms.txt + {len(batch_results)} linked pages)") + return crawl_results, crawl_type + + # For non-llms.txt discovery targets (sitemaps, robots.txt), keep single-file mode + logger.info(f"Discovery single-file mode: skipping link extraction for {url}") + crawl_type = "discovery_single_file" + logger.info(f"Discovery file crawling completed: {len(crawl_results)} result") + return crawl_results, crawl_type + # Extract links WITH text from the content extracted_links_with_text = self.url_handler.extract_markdown_links_with_text(content, url) @@ -714,6 +953,19 @@ async def update_crawl_progress(stage_progress: int, message: str, **kwargs): if self_filtered_count > 0: logger.info(f"Filtered out {self_filtered_count} self-referential links from {original_count} extracted links") + # For discovery targets, only follow same-domain links + if extracted_links_with_text and request.get("is_discovery_target"): + original_domain = request.get("original_domain") + if original_domain: + original_count = len(extracted_links_with_text) + extracted_links_with_text = [ + (link, text) for link, text in extracted_links_with_text + if self._is_same_domain(link, original_domain) + ] + domain_filtered_count = original_count - len(extracted_links_with_text) + if domain_filtered_count > 0: + safe_logfire_info(f"Discovery mode: 
filtered out {domain_filtered_count} external links, keeping {len(extracted_links_with_text)} same-domain links") + # Filter out binary files (PDFs, images, archives, etc.) to avoid wasteful crawling if extracted_links_with_text: original_count = len(extracted_links_with_text) @@ -724,26 +976,39 @@ async def update_crawl_progress(stage_progress: int, message: str, **kwargs): if extracted_links_with_text: # Build mapping of URL -> link text for title fallback - url_to_link_text = {link: text for link, text in extracted_links_with_text} + url_to_link_text = dict(extracted_links_with_text) extracted_links = [link for link, _ in extracted_links_with_text] - # Crawl the extracted links using batch crawling - logger.info(f"Crawling {len(extracted_links)} extracted links from {url}") - batch_results = await self.crawl_batch_with_progress( - extracted_links, - max_concurrent=request.get('max_concurrent'), # None -> use DB settings - progress_callback=await self._create_crawl_progress_callback("crawling"), - link_text_fallbacks=url_to_link_text, # Pass link text for title fallback - ) + # For discovery targets, respect max_depth for same-domain links + max_depth = request.get('max_depth', 2) if request.get("is_discovery_target") else request.get('max_depth', 1) + + if max_depth > 1 and request.get("is_discovery_target"): + # Use recursive crawling to respect depth limit for same-domain links + logger.info(f"Crawling {len(extracted_links)} same-domain links with max_depth={max_depth-1}") + batch_results = await self.crawl_recursive_with_progress( + extracted_links, + max_depth=max_depth - 1, # Reduce depth since we're already 1 level deep + max_concurrent=request.get('max_concurrent'), + progress_callback=await self._create_crawl_progress_callback("crawling"), + ) + else: + # Use normal batch crawling (with link text fallbacks) + logger.info(f"Crawling {len(extracted_links)} extracted links from {url}") + batch_results = await self.crawl_batch_with_progress( + extracted_links, + max_concurrent=request.get('max_concurrent'), # None -> use DB settings + progress_callback=await self._create_crawl_progress_callback("crawling"), + link_text_fallbacks=url_to_link_text, # Pass link text for title fallback + ) # Combine original text file results with batch results crawl_results.extend(batch_results) crawl_type = "link_collection_with_crawled_links" logger.info(f"Link collection crawling completed: {len(crawl_results)} total results (1 text file + {len(batch_results)} extracted links)") - else: - logger.info(f"No valid links found in link collection file: {url}") - logger.info(f"Text file crawling completed: {len(crawl_results)} results") + else: + logger.info(f"No valid links found in link collection file: {url}") + logger.info(f"Text file crawling completed: {len(crawl_results)} results") elif self.url_handler.is_sitemap(url): # Handle sitemaps @@ -753,6 +1018,20 @@ async def update_crawl_progress(stage_progress: int, message: str, **kwargs): "Detected sitemap, parsing URLs...", crawl_type=crawl_type ) + + # If this sitemap was selected by discovery, just return the sitemap itself (single-file mode) + if request.get("is_discovery_target"): + logger.info(f"Discovery single-file mode: returning sitemap itself without crawling URLs from {url}") + crawl_type = "discovery_sitemap" + # Return the sitemap file as the result + crawl_results = [{ + 'url': url, + 'markdown': f"# Sitemap: {url}\n\nThis is a sitemap file discovered and returned in single-file mode.", + 'title': f"Sitemap - 
{self.url_handler.extract_display_name(url)}", + 'crawl_type': crawl_type + }] + return crawl_results, crawl_type + sitemap_urls = self.parse_sitemap(url) if sitemap_urls: diff --git a/python/src/server/services/crawling/discovery_service.py b/python/src/server/services/crawling/discovery_service.py new file mode 100644 index 0000000000..103a277296 --- /dev/null +++ b/python/src/server/services/crawling/discovery_service.py @@ -0,0 +1,558 @@ +""" +Discovery Service for Automatic File Detection + +Handles automatic discovery and parsing of llms.txt, sitemap.xml, and related files +to enhance crawling capabilities with priority-based discovery methods. +""" + +import ipaddress +import socket +from html.parser import HTMLParser +from urllib.parse import urljoin, urlparse + +import requests + +from ...config.logfire_config import get_logger + +logger = get_logger(__name__) + + +class SitemapHTMLParser(HTMLParser): + """HTML parser for extracting sitemap references from link and meta tags.""" + + def __init__(self): + super().__init__() + self.sitemaps = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]): + """Handle start tags to find sitemap references.""" + attrs_dict = {k.lower(): v for k, v in attrs if v is not None} + + # Check + if tag == 'link': + rel = attrs_dict.get('rel', '').lower() + # Handle multi-valued rel attributes (space-separated) + rel_values = rel.split() if rel else [] + if 'sitemap' in rel_values: + href = attrs_dict.get('href') + if href: + self.sitemaps.append(('link', href)) + + # Check + elif tag == 'meta': + name = attrs_dict.get('name', '').lower() + if name == 'sitemap': + content = attrs_dict.get('content') + if content: + self.sitemaps.append(('meta', content)) + + +class DiscoveryService: + """Service for discovering related files automatically during crawls.""" + + # Maximum response size to prevent memory exhaustion (10MB default) + MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB + + # Global priority order - select ONE best file from all categories + # Based on actual usage research - only includes files commonly found in the wild + DISCOVERY_PRIORITY = [ + # LLMs files (highest priority - most comprehensive AI guidance) + "llms.txt", # Standard llms.txt spec - widely adopted + "llms-full.txt", # Part of llms.txt spec - comprehensive content + # Sitemap files (structural crawling guidance) + "sitemap.xml", # Universal standard for site structure + # Robots file (basic crawling rules) + "robots.txt", # Universal standard for crawl directives + # Well-known variants (alternative locations per RFC 8615) + ".well-known/ai.txt", + ".well-known/llms.txt", + ".well-known/sitemap.xml" + ] + + # Known file extensions for path detection + FILE_EXTENSIONS = { + '.html', '.htm', '.xml', '.json', '.txt', '.md', '.csv', + '.rss', '.yaml', '.yml', '.pdf', '.zip' + } + + def discover_files(self, base_url: str) -> str | None: + """ + Main discovery orchestrator - selects ONE best file across all categories. + All files contain similar AI/crawling guidance, so we only need the best one. 
+ + Args: + base_url: Base URL to discover files for + + Returns: + Single best URL found, or None if no files discovered + """ + try: + logger.info(f"Starting single-file discovery for {base_url}") + + # Extract directory path from base URL + base_dir = self._extract_directory(base_url) + + # Try each file in priority order + for filename in self.DISCOVERY_PRIORITY: + discovered_url = self._try_locations(base_url, base_dir, filename) + if discovered_url: + logger.info(f"Discovery found best file: {discovered_url}") + return discovered_url + + # Fallback: Check HTML meta tags for sitemap references + html_sitemaps = self._parse_html_meta_tags(base_url) + if html_sitemaps: + best_file = html_sitemaps[0] + logger.info(f"Discovery found best file from HTML meta tags: {best_file}") + return best_file + + logger.info(f"Discovery completed for {base_url}: no files found") + return None + + except Exception: + logger.exception(f"Unexpected error during discovery for {base_url}") + return None + + def _extract_directory(self, base_url: str) -> str: + """ + Extract directory path from URL, handling both file URLs and directory URLs. + + Args: + base_url: URL to extract directory from + + Returns: + Directory path (without trailing slash) + """ + parsed = urlparse(base_url) + base_path = parsed.path.rstrip('/') + + # Check if last segment is a file (has known extension) + last_segment = base_path.split('/')[-1] if base_path else '' + has_file_extension = any(last_segment.lower().endswith(ext) for ext in self.FILE_EXTENSIONS) + + if has_file_extension: + # Remove filename to get directory + return '/'.join(base_path.split('/')[:-1]) + else: + # Last segment is a directory + return base_path + + def _try_locations(self, base_url: str, base_dir: str, filename: str) -> str | None: + """ + Try different locations for a given filename in priority order. + + Priority: + 1. Same directory as base_url (if not root) + 2. Root level + 3. Common subdirectories (based on file type) + + Args: + base_url: Original base URL + base_dir: Extracted directory path + filename: Filename to search for + + Returns: + URL if file found, None otherwise + """ + parsed = urlparse(base_url) + + # Priority 1: Check same directory (if not root) + if base_dir and base_dir != '/': + same_dir_url = f"{parsed.scheme}://{parsed.netloc}{base_dir}/{filename}" + if self._check_url_exists(same_dir_url): + return same_dir_url + + # Priority 2: Check root level + root_url = urljoin(base_url, filename) + if self._check_url_exists(root_url): + return root_url + + # Priority 3: Check common subdirectories + subdirs = self._get_subdirs_for_file(base_dir, filename) + for subdir in subdirs: + subdir_url = urljoin(base_url, f"{subdir}/{filename}") + if self._check_url_exists(subdir_url): + return subdir_url + + return None + + def _get_subdirs_for_file(self, base_dir: str, filename: str) -> list[str]: + """ + Get relevant subdirectories to check based on file type. 
+ + Args: + base_dir: Base directory path + filename: Filename being searched for + + Returns: + List of subdirectory names to check + """ + subdirs = [] + + # Include base directory name if available + if base_dir and base_dir != '/': + base_dir_name = base_dir.split('/')[-1] + if base_dir_name: + subdirs.append(base_dir_name) + + # Add type-specific subdirectories + if filename.startswith('llms') or filename.endswith('.txt') or filename.endswith('.md'): + # LLMs files commonly in these locations + subdirs.extend(["docs", "static", "public", "assets", "doc", "api"]) + elif filename.endswith('.xml') and not filename.startswith('.well-known'): + # Sitemap files commonly in these locations + subdirs.extend(["docs", "sitemaps", "sitemap", "xml", "feed"]) + + return subdirs + + def _is_safe_ip(self, ip_str: str) -> bool: + """ + Check if an IP address is safe (not private, loopback, link-local, or cloud metadata). + + Args: + ip_str: IP address string to check + + Returns: + True if IP is safe for outbound requests, False otherwise + """ + try: + ip = ipaddress.ip_address(ip_str) + + # Block private networks + if ip.is_private: + logger.warning(f"Blocked private IP address: {ip_str}") + return False + + # Block loopback (127.0.0.0/8, ::1) + if ip.is_loopback: + logger.warning(f"Blocked loopback IP address: {ip_str}") + return False + + # Block link-local (169.254.0.0/16, fe80::/10) + if ip.is_link_local: + logger.warning(f"Blocked link-local IP address: {ip_str}") + return False + + # Block multicast + if ip.is_multicast: + logger.warning(f"Blocked multicast IP address: {ip_str}") + return False + + # Block reserved ranges + if ip.is_reserved: + logger.warning(f"Blocked reserved IP address: {ip_str}") + return False + + # Additional explicit checks for cloud metadata services + # AWS metadata service + if str(ip) == "169.254.169.254": + logger.warning(f"Blocked AWS metadata service IP: {ip_str}") + return False + + # GCP metadata service + if str(ip) == "169.254.169.254": + logger.warning(f"Blocked GCP metadata service IP: {ip_str}") + return False + + return True + + except ValueError: + logger.warning(f"Invalid IP address format: {ip_str}") + return False + + def _resolve_and_validate_hostname(self, hostname: str) -> bool: + """ + Resolve hostname to IP and validate it's safe. + + Args: + hostname: Hostname to resolve and validate + + Returns: + True if hostname resolves to safe IPs only, False otherwise + """ + try: + # Resolve hostname to IP addresses + addr_info = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM) + + # Check all resolved IPs + for info in addr_info: + ip_str = info[4][0] + if not self._is_safe_ip(ip_str): + logger.warning(f"Hostname {hostname} resolves to unsafe IP {ip_str}") + return False + + return True + + except socket.gaierror as e: + logger.warning(f"DNS resolution failed for {hostname}: {e}") + return False + except Exception as e: + logger.warning(f"Error resolving hostname {hostname}: {e}") + return False + + def _check_url_exists(self, url: str) -> bool: + """ + Check if a URL exists and returns a successful response. + Includes SSRF protection by validating hostnames and blocking private IPs. 
+ + Args: + url: URL to check + + Returns: + True if URL returns 200, False otherwise + """ + try: + # Parse URL to extract hostname + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + logger.warning(f"Invalid URL format: {url}") + return False + + # Only allow HTTP/HTTPS + if parsed.scheme not in ('http', 'https'): + logger.warning(f"Blocked non-HTTP(S) scheme: {parsed.scheme}") + return False + + # Validate initial hostname + hostname = parsed.netloc.split(':')[0] # Remove port if present + if not self._resolve_and_validate_hostname(hostname): + logger.warning(f"URL check blocked due to unsafe hostname: {url}") + return False + + # Set safe User-Agent header + headers = { + 'User-Agent': 'Archon-Discovery/1.0 (SSRF-Protected)' + } + + # Create a session with limited redirects + session = requests.Session() + session.max_redirects = 3 + + # Make request with redirect validation + resp = session.get( + url, + timeout=5, + allow_redirects=True, + verify=True, + headers=headers + ) + + try: + # Check if there were redirects (history attribute exists on real responses) + if hasattr(resp, 'history') and resp.history: + logger.debug(f"URL {url} had {len(resp.history)} redirect(s)") + + # Validate final destination + final_url = resp.url + final_parsed = urlparse(final_url) + + # Only allow HTTP/HTTPS for final destination + if final_parsed.scheme not in ('http', 'https'): + logger.warning(f"Blocked redirect to non-HTTP(S) scheme: {final_parsed.scheme}") + return False + + # Validate final hostname + final_hostname = final_parsed.netloc.split(':')[0] + if not self._resolve_and_validate_hostname(final_hostname): + logger.warning(f"Redirect target blocked due to unsafe hostname: {final_url}") + return False + + # Check response status + success = resp.status_code == 200 + logger.debug(f"URL check: {url} -> {resp.status_code} ({'exists' if success else 'not found'})") + return success + + finally: + if hasattr(resp, 'close'): + resp.close() + + except requests.exceptions.TooManyRedirects: + logger.warning(f"Too many redirects for URL: {url}") + return False + except requests.exceptions.Timeout: + logger.debug(f"Timeout checking URL: {url}") + return False + except requests.exceptions.RequestException as e: + logger.debug(f"Request error checking URL {url}: {e}") + return False + except Exception as e: + logger.warning(f"Unexpected error checking URL {url}: {e}", exc_info=True) + return False + + def _parse_robots_txt(self, base_url: str) -> list[str]: + """ + Extract sitemap URLs from robots.txt. 
+ + Args: + base_url: Base URL to check robots.txt for + + Returns: + List of sitemap URLs found in robots.txt + """ + sitemaps: list[str] = [] + + try: + robots_url = urljoin(base_url, "robots.txt") + logger.info(f"Checking robots.txt at {robots_url}") + + # Set safe User-Agent header + headers = { + 'User-Agent': 'Archon-Discovery/1.0 (SSRF-Protected)' + } + + resp = requests.get(robots_url, timeout=30, stream=True, verify=True, headers=headers) + + try: + if resp.status_code != 200: + logger.info(f"No robots.txt found: HTTP {resp.status_code}") + return sitemaps + + # Read response with size limit + content = self._read_response_with_limit(resp, robots_url) + + # Parse robots.txt content for sitemap directives + for raw_line in content.splitlines(): + line = raw_line.strip() + if line.lower().startswith("sitemap:"): + sitemap_value = line.split(":", 1)[1].strip() + if sitemap_value: + # Allow absolute and relative sitemap values + if sitemap_value.lower().startswith(("http://", "https://")): + sitemap_url = sitemap_value + else: + # Resolve relative path against base_url + sitemap_url = urljoin(base_url, sitemap_value) + + # Validate scheme is HTTP/HTTPS only + parsed = urlparse(sitemap_url) + if parsed.scheme not in ("http", "https"): + logger.warning(f"Skipping non-HTTP(S) sitemap in robots.txt: {sitemap_url}") + continue + + sitemaps.append(sitemap_url) + logger.info(f"Found sitemap in robots.txt: {sitemap_url}") + + finally: + resp.close() + + except requests.exceptions.RequestException: + logger.exception(f"Network error fetching robots.txt from {base_url}") + except ValueError as e: + logger.warning(f"robots.txt too large at {base_url}: {e}") + except Exception: + logger.exception(f"Unexpected error parsing robots.txt from {base_url}") + + return sitemaps + + def _parse_html_meta_tags(self, base_url: str) -> list[str]: + """ + Extract sitemap references from HTML meta tags using proper HTML parsing. 
+ + Args: + base_url: Base URL to check HTML for meta tags + + Returns: + List of sitemap URLs found in HTML meta tags + """ + sitemaps: list[str] = [] + + try: + logger.info(f"Checking HTML meta tags for sitemaps at {base_url}") + + # Set safe User-Agent header + headers = { + 'User-Agent': 'Archon-Discovery/1.0 (SSRF-Protected)' + } + + resp = requests.get(base_url, timeout=30, stream=True, verify=True, headers=headers) + + try: + if resp.status_code != 200: + logger.debug(f"Could not fetch HTML for meta tag parsing: HTTP {resp.status_code}") + return sitemaps + + # Read response with size limit + content = self._read_response_with_limit(resp, base_url) + + # Parse HTML using proper HTML parser + parser = SitemapHTMLParser() + try: + parser.feed(content) + except Exception as e: + logger.warning(f"HTML parsing error for {base_url}: {e}") + return sitemaps + + # Process found sitemaps + for tag_type, url in parser.sitemaps: + # Resolve relative URLs + sitemap_url = urljoin(base_url, url.strip()) + + # Validate scheme is HTTP/HTTPS + parsed = urlparse(sitemap_url) + if parsed.scheme not in ("http", "https"): + logger.debug(f"Skipping non-HTTP(S) sitemap URL: {sitemap_url}") + continue + + sitemaps.append(sitemap_url) + logger.info(f"Found sitemap in HTML {tag_type} tag: {sitemap_url}") + + finally: + resp.close() + + except requests.exceptions.RequestException: + logger.exception(f"Network error fetching HTML from {base_url}") + except ValueError as e: + logger.warning(f"HTML response too large at {base_url}: {e}") + except Exception: + logger.exception(f"Unexpected error parsing HTML meta tags from {base_url}") + + return sitemaps + + def _read_response_with_limit(self, response: requests.Response, url: str, max_size: int | None = None) -> str: + """ + Read response content with size limit to prevent memory exhaustion. 
+ + Args: + response: The response object to read from + url: URL being read (for logging) + max_size: Maximum bytes to read (defaults to MAX_RESPONSE_SIZE) + + Returns: + Response text content + + Raises: + ValueError: If response exceeds size limit + """ + if max_size is None: + max_size = self.MAX_RESPONSE_SIZE + + try: + chunks = [] + total_size = 0 + + # Read response in chunks to enforce size limit + for chunk in response.iter_content(chunk_size=8192, decode_unicode=False): + if chunk: + total_size += len(chunk) + if total_size > max_size: + response.close() + size_mb = max_size / (1024 * 1024) + logger.warning( + f"Response size exceeded limit of {size_mb:.1f}MB for {url}, " + f"received {total_size / (1024 * 1024):.1f}MB" + ) + raise ValueError(f"Response size exceeds {size_mb:.1f}MB limit") + chunks.append(chunk) + + # Decode the complete response + content_bytes = b''.join(chunks) + encoding = response.encoding or 'utf-8' + try: + return content_bytes.decode(encoding) + except UnicodeDecodeError: + # Fallback to utf-8 with error replacement + return content_bytes.decode('utf-8', errors='replace') + + except Exception: + response.close() + raise diff --git a/python/src/server/services/crawling/helpers/url_handler.py b/python/src/server/services/crawling/helpers/url_handler.py index 3cf0f1dc40..f243c2ab00 100644 --- a/python/src/server/services/crawling/helpers/url_handler.py +++ b/python/src/server/services/crawling/helpers/url_handler.py @@ -6,8 +6,8 @@ import hashlib import re -from urllib.parse import urlparse, urljoin from typing import List, Optional +from urllib.parse import urljoin, urlparse from ....config.logfire_config import get_logger @@ -36,8 +36,8 @@ def is_sitemap(url: str) -> bool: except Exception as e: logger.warning(f"Error checking if URL is sitemap: {e}") return False - - @staticmethod + + @staticmethod def is_markdown(url: str) -> bool: """ Check if a URL points to a markdown file (.md, .mdx, .markdown). @@ -277,9 +277,9 @@ def generate_unique_source_id(url: str) -> str: # Fallback: use a hash of the error message + url to still get something unique fallback = f"error_{redacted}_{str(e)}" return hashlib.sha256(fallback.encode("utf-8")).hexdigest()[:16] - + @staticmethod - def extract_markdown_links(content: str, base_url: Optional[str] = None) -> List[str]: + def extract_markdown_links(content: str, base_url: str | None = None) -> list[str]: """ Extract markdown-style links from text content. @@ -385,9 +385,9 @@ def _clean_url(u: str) -> str: except Exception as e: logger.error(f"Error extracting markdown links with text: {e}", exc_info=True) return [] - + @staticmethod - def is_link_collection_file(url: str, content: Optional[str] = None) -> bool: + def is_link_collection_file(url: str, content: str | None = None) -> bool: """ Check if a URL/file appears to be a link collection file like llms.txt. 
@@ -402,56 +402,55 @@ def is_link_collection_file(url: str, content: Optional[str] = None) -> bool: # Extract filename from URL parsed = urlparse(url) filename = parsed.path.split('/')[-1].lower() - + # Check for specific link collection filenames # Note: "full-*" or "*-full" patterns are NOT link collections - they contain complete content, not just links + # Only includes commonly used formats found in the wild link_collection_patterns = [ # .txt variants - files that typically contain lists of links 'llms.txt', 'links.txt', 'resources.txt', 'references.txt', - # .md/.mdx/.markdown variants - 'llms.md', 'links.md', 'resources.md', 'references.md', - 'llms.mdx', 'links.mdx', 'resources.mdx', 'references.mdx', - 'llms.markdown', 'links.markdown', 'resources.markdown', 'references.markdown', ] - + # Direct filename match if filename in link_collection_patterns: logger.info(f"Detected link collection file by filename: {filename}") return True - + # Pattern-based detection for variations, but exclude "full" variants # Only match files that are likely link collections, not complete content files - if filename.endswith(('.txt', '.md', '.mdx', '.markdown')): - # Exclude files with "full" in the name - these typically contain complete content, not just links - if 'full' not in filename: + if filename.endswith('.txt'): + # Exclude files with "full" as standalone token (avoid false positives like "helpful.md") + import re + if not re.search(r'(^|[._-])full([._-]|$)', filename): # Match files that start with common link collection prefixes base_patterns = ['llms', 'links', 'resources', 'references'] if any(filename.startswith(pattern + '.') or filename.startswith(pattern + '-') for pattern in base_patterns): logger.info(f"Detected potential link collection file: {filename}") return True - + # Content-based detection if content is provided if content: # Never treat "full" variants as link collections to preserve single-page behavior - if 'full' in filename: + import re + if re.search(r'(^|[._-])full([._-]|$)', filename): logger.info(f"Skipping content-based link-collection detection for full-content file: {filename}") return False # Reuse extractor to avoid regex divergence and maintain consistency extracted_links = URLHandler.extract_markdown_links(content, url) total_links = len(extracted_links) - + # Calculate link density (links per 100 characters) content_length = len(content.strip()) if content_length > 0: link_density = (total_links * 100) / content_length - + # If more than 2% of content is links, likely a link collection if link_density > 2.0 and total_links > 3: logger.info(f"Detected link collection by content analysis: {total_links} links, density {link_density:.2f}%") return True - + return False - + except Exception as e: logger.warning(f"Error checking if file is link collection: {e}", exc_info=True) return False @@ -605,3 +604,104 @@ def extract_display_name(url: str) -> str: logger.warning(f"Error extracting display name for {url}: {e}, using URL") # Fallback: return truncated URL return url[:50] + "..." if len(url) > 50 else url + + @staticmethod + def is_robots_txt(url: str) -> bool: + """ + Check if a URL is a robots.txt file with error handling. 
+ + Args: + url: URL to check + + Returns: + True if URL is a robots.txt file, False otherwise + """ + try: + parsed = urlparse(url) + # Normalize to lowercase and ignore query/fragment + path = parsed.path.lower() + # Only detect robots.txt at root level + return path == '/robots.txt' + except Exception as e: + logger.warning(f"Error checking if URL is robots.txt: {e}", exc_info=True) + return False + + @staticmethod + def is_llms_variant(url: str) -> bool: + """ + Check if a URL is a llms.txt/llms.md variant with error handling. + + Matches: + - Exact filename matches: llms.txt, llms-full.txt, llms.md, etc. + - Files in /llms/ directories: /llms/guides.txt, /llms/swift.txt, etc. + + Args: + url: URL to check + + Returns: + True if URL is a llms file variant, False otherwise + """ + try: + parsed = urlparse(url) + # Normalize to lowercase and ignore query/fragment + path = parsed.path.lower() + filename = path.split('/')[-1] if '/' in path else path + + # Check for exact llms file variants (only standard spec files) + llms_variants = ['llms.txt', 'llms-full.txt'] + if filename in llms_variants: + return True + + # Check for .txt files in /llms/ directory (e.g., /llms/guides.txt, /llms/swift.txt) + if '/llms/' in path and path.endswith('.txt'): + return True + + return False + except Exception as e: + logger.warning(f"Error checking if URL is llms variant: {e}", exc_info=True) + return False + + @staticmethod + def is_well_known_file(url: str) -> bool: + """ + Check if a URL is a .well-known/* file with error handling. + Per RFC 8615, the path is case-sensitive and must be lowercase. + + Args: + url: URL to check + + Returns: + True if URL is a .well-known file, False otherwise + """ + try: + parsed = urlparse(url) + # RFC 8615: path segments are case-sensitive, must be lowercase + path = parsed.path + # Only detect .well-known files at root level + return path.startswith('/.well-known/') and path.count('/.well-known/') == 1 + except Exception as e: + logger.warning(f"Error checking if URL is well-known file: {e}", exc_info=True) + return False + + @staticmethod + def get_base_url(url: str) -> str: + """ + Extract base domain URL for discovery with error handling. 
+ + Args: + url: URL to extract base from + + Returns: + Base URL (scheme + netloc) or original URL if extraction fails + """ + try: + parsed = urlparse(url) + # Ensure we have scheme and netloc + if parsed.scheme and parsed.netloc: + return f"{parsed.scheme}://{parsed.netloc}" + else: + logger.warning(f"URL missing scheme or netloc: {url}") + return url + except Exception as e: + logger.warning(f"Error extracting base URL from {url}: {e}", exc_info=True) + return url diff --git a/python/src/server/services/crawling/progress_mapper.py b/python/src/server/services/crawling/progress_mapper.py index 5efe24938f..81d56336c5 100644 --- a/python/src/server/services/crawling/progress_mapper.py +++ b/python/src/server/services/crawling/progress_mapper.py @@ -18,14 +18,18 @@ class ProgressMapper: "error": (-1, -1), # Special case for errors "cancelled": (-1, -1), # Special case for cancellation "completed": (100, 100), + "complete": (100, 100), # Alias # Crawl-specific stages - rebalanced based on actual time taken "analyzing": (1, 3), # URL analysis is quick - "crawling": (3, 15), # Crawling can take time for deep/many URLs + "discovery": (3, 4), # File discovery is quick (new stage for discovery feature) + "crawling": (4, 15), # Crawling can take time for deep/many URLs "processing": (15, 20), # Content processing/chunking "source_creation": (20, 25), # DB operations "document_storage": (25, 40), # Embeddings generation takes significant time "code_extraction": (40, 90), # Code extraction + summaries - still longest but more balanced + "code_storage": (40, 90), # Alias + "extracting": (40, 90), # Alias for code_extraction "finalization": (90, 100), # Final steps and cleanup # Upload-specific stages @@ -65,7 +69,7 @@ def map_progress(self, stage: str, stage_progress: float) -> int: start, end = self.STAGE_RANGES[stage] # Handle completion - if stage == "completed": + if stage in ["completed", "complete"]: self.last_overall_progress = 100 return 100 diff --git a/python/src/server/services/crawling/strategies/single_page.py b/python/src/server/services/crawling/strategies/single_page.py index 58610d0130..96ea5bb55f 100644 --- a/python/src/server/services/crawling/strategies/single_page.py +++ b/python/src/server/services/crawling/strategies/single_page.py @@ -229,17 +229,43 @@ async def crawl_markdown_file( ) -> list[dict[str, Any]]: """ Crawl a .txt or markdown file with comprehensive error handling and progress reporting. 
- + Args: url: URL of the text/markdown file transform_url_func: Function to transform URLs (e.g., GitHub URLs) progress_callback: Optional callback for progress updates - start_progress: Starting progress percentage - end_progress: Ending progress percentage - + start_progress: Starting progress percentage (must be 0-100) + end_progress: Ending progress percentage (must be 0-100 and > start_progress) + Returns: List containing the crawled document + + Raises: + ValueError: If start_progress or end_progress are invalid """ + # Validate progress parameters before any async work or progress reporting + if not isinstance(start_progress, (int, float)) or not isinstance(end_progress, (int, float)): + raise ValueError( + f"start_progress and end_progress must be int or float, " + f"got start_progress={type(start_progress).__name__}, end_progress={type(end_progress).__name__}" + ) + + if not (0 <= start_progress <= 100): + raise ValueError( + f"start_progress must be in range [0, 100], got {start_progress}" + ) + + if not (0 <= end_progress <= 100): + raise ValueError( + f"end_progress must be in range [0, 100], got {end_progress}" + ) + + if start_progress >= end_progress: + raise ValueError( + f"start_progress must be less than end_progress, " + f"got start_progress={start_progress}, end_progress={end_progress}" + ) + try: # Transform GitHub URLs to raw content URLs if applicable original_url = url diff --git a/python/tests/progress_tracking/test_progress_mapper.py b/python/tests/progress_tracking/test_progress_mapper.py index d573975595..37532f8817 100644 --- a/python/tests/progress_tracking/test_progress_mapper.py +++ b/python/tests/progress_tracking/test_progress_mapper.py @@ -13,109 +13,119 @@ class TestProgressMapper: def test_initialization(self): """Test ProgressMapper initialization""" mapper = ProgressMapper() - + assert mapper.last_overall_progress == 0 assert mapper.current_stage == "starting" - + def test_map_progress_basic(self): """Test basic progress mapping""" mapper = ProgressMapper() - + # Starting stage (0-1%) progress = mapper.map_progress("starting", 50) assert progress == 0 # 50% of 0-1 range - + # Analyzing stage (1-3%) progress = mapper.map_progress("analyzing", 50) assert progress == 2 # 1 + (50% of 2) = 2 - - # Crawling stage (3-15%) + + # Discovery stage (3-4%) - NEW TEST FOR DISCOVERY FEATURE + progress = mapper.map_progress("discovery", 50) + assert progress == 4 # 3 + (50% of 1) = 3.5 -> 4 (rounds up) + + # Crawling stage (4-15%) progress = mapper.map_progress("crawling", 50) - assert progress == 9 # 3 + (50% of 12) = 9 - + assert progress == 10 # 4 + (50% of 11) = 9.5 -> 10 (rounds up) + def test_progress_never_goes_backwards(self): """Test that progress never decreases""" mapper = ProgressMapper() - - # Move to 50% of crawling (3-15%) = 9% + + # Move to 50% of crawling (4-15%) = 9.5 -> 10% progress1 = mapper.map_progress("crawling", 50) - assert progress1 == 9 - - # Try to go back to analyzing (1-3%) - should stay at 9% + assert progress1 == 10 + + # Try to go back to analyzing (1-3%) - should stay at 10% progress2 = mapper.map_progress("analyzing", 100) - assert progress2 == 9 # Should not go backwards - + assert progress2 == 10 # Should not go backwards + # Can move forward to document_storage progress3 = mapper.map_progress("document_storage", 50) assert progress3 == 32 # 25 + (50% of 15) = 32.5 -> 32 - + def test_completion_handling(self): """Test completion status handling""" mapper = ProgressMapper() - + # Jump straight to completed progress = 
mapper.map_progress("completed", 0) assert progress == 100 - + # Any percentage at completed should be 100 progress = mapper.map_progress("completed", 50) assert progress == 100 - + + # Test alias 'complete' + mapper2 = ProgressMapper() + progress = mapper2.map_progress("complete", 0) + assert progress == 100 + def test_error_handling(self): """Test error status handling - preserves last known progress""" mapper = ProgressMapper() - + # Error with no prior progress should return 0 (initial state) progress = mapper.map_progress("error", 50) assert progress == 0 - + # Set some progress first, then error should preserve it mapper.map_progress("crawling", 50) # Should map to somewhere in the crawling range current_progress = mapper.last_overall_progress error_progress = mapper.map_progress("error", 50) assert error_progress == current_progress # Should preserve the progress - + def test_cancelled_handling(self): """Test cancelled status handling - preserves last known progress""" mapper = ProgressMapper() - + # Cancelled with no prior progress should return 0 (initial state) progress = mapper.map_progress("cancelled", 50) assert progress == 0 - + # Set some progress first, then cancelled should preserve it mapper.map_progress("crawling", 75) # Should map to somewhere in the crawling range current_progress = mapper.last_overall_progress cancelled_progress = mapper.map_progress("cancelled", 50) assert cancelled_progress == current_progress # Should preserve the progress - + def test_unknown_stage(self): """Test handling of unknown stages""" mapper = ProgressMapper() - + # Set some initial progress mapper.map_progress("crawling", 50) current = mapper.last_overall_progress - + # Unknown stage should maintain current progress progress = mapper.map_progress("unknown_stage", 50) assert progress == current - - def test_stage_ranges(self): - """Test all defined stage ranges""" + + def test_stage_ranges_with_discovery(self): + """Test all defined stage ranges including discovery""" mapper = ProgressMapper() - + # Verify ranges are correctly defined with new balanced values assert mapper.STAGE_RANGES["starting"] == (0, 1) assert mapper.STAGE_RANGES["analyzing"] == (1, 3) - assert mapper.STAGE_RANGES["crawling"] == (3, 15) + assert mapper.STAGE_RANGES["discovery"] == (3, 4) # NEW DISCOVERY STAGE + assert mapper.STAGE_RANGES["crawling"] == (4, 15) assert mapper.STAGE_RANGES["processing"] == (15, 20) assert mapper.STAGE_RANGES["source_creation"] == (20, 25) assert mapper.STAGE_RANGES["document_storage"] == (25, 40) assert mapper.STAGE_RANGES["code_extraction"] == (40, 90) assert mapper.STAGE_RANGES["finalization"] == (90, 100) assert mapper.STAGE_RANGES["completed"] == (100, 100) - + # Upload-specific stages assert mapper.STAGE_RANGES["reading"] == (0, 5) assert mapper.STAGE_RANGES["text_extraction"] == (5, 10) @@ -123,138 +133,167 @@ def test_stage_ranges(self): # Note: source_creation is shared between crawl and upload operations at (20, 25) assert mapper.STAGE_RANGES["summarizing"] == (25, 35) assert mapper.STAGE_RANGES["storing"] == (35, 100) - + def test_calculate_stage_progress(self): """Test calculating percentage within a stage""" mapper = ProgressMapper() - + # 5 out of 10 = 50% progress = mapper.calculate_stage_progress(5, 10) assert progress == 50.0 - + # 0 out of 10 = 0% progress = mapper.calculate_stage_progress(0, 10) assert progress == 0.0 - + # 10 out of 10 = 100% progress = mapper.calculate_stage_progress(10, 10) assert progress == 100.0 - + # Handle division by zero progress = 
mapper.calculate_stage_progress(5, 0) assert progress == 0.0 - + def test_map_batch_progress(self): """Test batch progress mapping""" mapper = ProgressMapper() - + # Batch 1 of 5 in document_storage stage progress = mapper.map_batch_progress("document_storage", 1, 5) assert progress == 25 # Start of document_storage range (25-40) - + # Batch 3 of 5 progress = mapper.map_batch_progress("document_storage", 3, 5) assert progress == 31 # 40% through 25-40 range - + # Batch 5 of 5 progress = mapper.map_batch_progress("document_storage", 5, 5) assert progress == 37 # 80% through 25-40 range - + def test_map_with_substage(self): """Test mapping with substage information""" mapper = ProgressMapper() - + # Currently just uses main stage progress = mapper.map_with_substage("document_storage", "embeddings", 50) assert progress == 32 # 50% of 25-40 range = 32.5 -> 32 - + def test_reset(self): """Test resetting the mapper""" mapper = ProgressMapper() - + # Set some progress mapper.map_progress("document_storage", 50) assert mapper.last_overall_progress == 32 # 25 + (50% of 15) = 32.5 -> 32 assert mapper.current_stage == "document_storage" - + # Reset mapper.reset() assert mapper.last_overall_progress == 0 assert mapper.current_stage == "starting" - + def test_get_current_stage(self): """Test getting current stage""" mapper = ProgressMapper() - + assert mapper.get_current_stage() == "starting" - + mapper.map_progress("crawling", 50) assert mapper.get_current_stage() == "crawling" - + mapper.map_progress("code_extraction", 50) assert mapper.get_current_stage() == "code_extraction" - + def test_get_current_progress(self): """Test getting current progress""" mapper = ProgressMapper() - + assert mapper.get_current_progress() == 0 - + mapper.map_progress("crawling", 50) - assert mapper.get_current_progress() == 9 # 3 + (50% of 12) = 9 - + assert mapper.get_current_progress() == 10 # 4 + (50% of 11) = 9.5 -> 10 + mapper.map_progress("code_extraction", 50) assert mapper.get_current_progress() == 65 # 40 + (50% of 50) = 65 - + def test_get_stage_range(self): """Test getting stage range""" mapper = ProgressMapper() - + assert mapper.get_stage_range("starting") == (0, 1) + assert mapper.get_stage_range("discovery") == (3, 4) # Test discovery stage assert mapper.get_stage_range("code_extraction") == (40, 90) assert mapper.get_stage_range("unknown") == (0, 100) # Default range - - def test_realistic_crawl_sequence(self): - """Test a realistic crawl progress sequence""" + + def test_realistic_crawl_sequence_with_discovery(self): + """Test a realistic crawl progress sequence including discovery""" mapper = ProgressMapper() - + # Starting assert mapper.map_progress("starting", 0) == 0 assert mapper.map_progress("starting", 100) == 1 - + # Analyzing assert mapper.map_progress("analyzing", 0) == 1 assert mapper.map_progress("analyzing", 100) == 3 - + + # Discovery (NEW) + assert mapper.map_progress("discovery", 0) == 3 + assert mapper.map_progress("discovery", 50) == 4 # 3 + (50% of 1) = 3.5 -> 4 (rounds up) + assert mapper.map_progress("discovery", 100) == 4 + # Crawling - assert mapper.map_progress("crawling", 0) == 3 - assert mapper.map_progress("crawling", 33) == 7 # 3 + (33% of 12) = 6.96 -> 7 - assert mapper.map_progress("crawling", 66) == 11 # 3 + (66% of 12) = 10.92 -> 11 + assert mapper.map_progress("crawling", 0) == 4 + assert mapper.map_progress("crawling", 33) == 8 # 4 + (33% of 11) = 7.63 -> 8 (rounds up) + progress_crawl_66 = mapper.map_progress("crawling", 66) + assert progress_crawl_66 in [11, 12] # 4 + 
(66% of 11) = 11.26, could round to 11 or 12 assert mapper.map_progress("crawling", 100) == 15 - + # Processing assert mapper.map_progress("processing", 0) == 15 assert mapper.map_progress("processing", 100) == 20 - + # Source creation assert mapper.map_progress("source_creation", 0) == 20 assert mapper.map_progress("source_creation", 100) == 25 - + # Document storage assert mapper.map_progress("document_storage", 0) == 25 assert mapper.map_progress("document_storage", 50) == 32 # 25 + (50% of 15) = 32.5 -> 32 assert mapper.map_progress("document_storage", 100) == 40 - + # Code extraction (longest phase) assert mapper.map_progress("code_extraction", 0) == 40 - assert mapper.map_progress("code_extraction", 25) == 52 # 40 + (25% of 50) = 52.5 -> 52 + progress_25 = mapper.map_progress("code_extraction", 25) + assert progress_25 in [52, 53] # 40 + (25% of 50) = 52.5, banker's rounding rounds to 52 (even) assert mapper.map_progress("code_extraction", 50) == 65 # 40 + (50% of 50) = 65 - assert mapper.map_progress("code_extraction", 75) == 78 # 40 + (75% of 50) = 77.5 -> 78 + progress_75 = mapper.map_progress("code_extraction", 75) + assert progress_75 == 78 # 40 + (75% of 50) = 77.5 -> 78 (rounds to even per banker's rounding) assert mapper.map_progress("code_extraction", 100) == 90 - + # Finalization assert mapper.map_progress("finalization", 0) == 90 assert mapper.map_progress("finalization", 100) == 100 - + # Completed - assert mapper.map_progress("completed", 0) == 100 \ No newline at end of file + assert mapper.map_progress("completed", 0) == 100 + + def test_aliases_work_correctly(self): + """Test that stage aliases work correctly""" + mapper = ProgressMapper() + + # Test code_storage alias for code_extraction + progress1 = mapper.map_progress("code_extraction", 50) + mapper2 = ProgressMapper() + progress2 = mapper2.map_progress("code_storage", 50) + assert progress1 == progress2 + + # Test extracting alias for code_extraction + mapper3 = ProgressMapper() + progress3 = mapper3.map_progress("extracting", 50) + assert progress1 == progress3 + + # Test complete alias for completed + mapper4 = ProgressMapper() + progress4 = mapper4.map_progress("complete", 0) + assert progress4 == 100 \ No newline at end of file diff --git a/python/tests/test_crawling_service_subdomain.py b/python/tests/test_crawling_service_subdomain.py new file mode 100644 index 0000000000..543423c8df --- /dev/null +++ b/python/tests/test_crawling_service_subdomain.py @@ -0,0 +1,152 @@ +"""Unit tests for CrawlingService subdomain checking functionality.""" +import pytest +from src.server.services.crawling.crawling_service import CrawlingService + + +class TestCrawlingServiceSubdomain: + """Test suite for CrawlingService subdomain checking methods.""" + + @pytest.fixture + def service(self): + """Create a CrawlingService instance for testing.""" + # Create service without crawler or supabase for testing domain checking + return CrawlingService(crawler=None, supabase_client=None) + + def test_is_same_domain_or_subdomain_exact_match(self, service): + """Test exact domain matches.""" + # Same domain should match + assert service._is_same_domain_or_subdomain( + "https://supabase.com/docs", + "https://supabase.com" + ) is True + + assert service._is_same_domain_or_subdomain( + "https://supabase.com/path/to/page", + "https://supabase.com" + ) is True + + def test_is_same_domain_or_subdomain_subdomains(self, service): + """Test subdomain matching.""" + # Subdomain should match + assert service._is_same_domain_or_subdomain( + 
"https://docs.supabase.com/llms.txt", + "https://supabase.com" + ) is True + + assert service._is_same_domain_or_subdomain( + "https://api.supabase.com/v1/endpoint", + "https://supabase.com" + ) is True + + # Multiple subdomain levels + assert service._is_same_domain_or_subdomain( + "https://dev.api.supabase.com/test", + "https://supabase.com" + ) is True + + def test_is_same_domain_or_subdomain_different_domains(self, service): + """Test that different domains are rejected.""" + # Different domain should not match + assert service._is_same_domain_or_subdomain( + "https://external.com/llms.txt", + "https://supabase.com" + ) is False + + assert service._is_same_domain_or_subdomain( + "https://docs.other-site.com", + "https://supabase.com" + ) is False + + # Similar but different domains + assert service._is_same_domain_or_subdomain( + "https://supabase.org", + "https://supabase.com" + ) is False + + def test_is_same_domain_or_subdomain_protocols(self, service): + """Test that protocol differences don't affect matching.""" + # Different protocols should still match + assert service._is_same_domain_or_subdomain( + "http://supabase.com/docs", + "https://supabase.com" + ) is True + + assert service._is_same_domain_or_subdomain( + "https://docs.supabase.com", + "http://supabase.com" + ) is True + + def test_is_same_domain_or_subdomain_ports(self, service): + """Test handling of port numbers.""" + # Same root domain with different ports should match + assert service._is_same_domain_or_subdomain( + "https://supabase.com:8080/api", + "https://supabase.com" + ) is True + + assert service._is_same_domain_or_subdomain( + "http://localhost:3000/dev", + "http://localhost:8080" + ) is True + + def test_is_same_domain_or_subdomain_edge_cases(self, service): + """Test edge cases and error handling.""" + # Empty or malformed URLs should return False + assert service._is_same_domain_or_subdomain( + "", + "https://supabase.com" + ) is False + + assert service._is_same_domain_or_subdomain( + "https://supabase.com", + "" + ) is False + + assert service._is_same_domain_or_subdomain( + "not-a-url", + "https://supabase.com" + ) is False + + def test_is_same_domain_or_subdomain_real_world_examples(self, service): + """Test with real-world examples.""" + # GitHub examples + assert service._is_same_domain_or_subdomain( + "https://api.github.com/repos", + "https://github.com" + ) is True + + assert service._is_same_domain_or_subdomain( + "https://raw.githubusercontent.com/owner/repo", + "https://github.com" + ) is False # githubusercontent.com is different root domain + + # Documentation sites + assert service._is_same_domain_or_subdomain( + "https://docs.python.org/3/library", + "https://python.org" + ) is True + + assert service._is_same_domain_or_subdomain( + "https://api.stripe.com/v1", + "https://stripe.com" + ) is True + + def test_is_same_domain_backward_compatibility(self, service): + """Test that _is_same_domain still works correctly for exact matches.""" + # Exact domain match should work + assert service._is_same_domain( + "https://supabase.com/docs", + "https://supabase.com" + ) is True + + # Subdomain should NOT match with _is_same_domain (only with _is_same_domain_or_subdomain) + assert service._is_same_domain( + "https://docs.supabase.com/llms.txt", + "https://supabase.com" + ) is False + + # Different domain should not match + assert service._is_same_domain( + "https://external.com/llms.txt", + "https://supabase.com" + ) is False diff --git a/python/tests/test_discovery_service.py 
b/python/tests/test_discovery_service.py new file mode 100644 index 0000000000..b7b41a9561 --- /dev/null +++ b/python/tests/test_discovery_service.py @@ -0,0 +1,353 @@ +"""Unit tests for DiscoveryService class.""" +import socket +from unittest.mock import Mock, patch + +from src.server.services.crawling.discovery_service import DiscoveryService + + +def create_mock_dns_response(): + """Create mock DNS response for safe public IPs.""" + # Return a safe public IP for testing + return [ + (socket.AF_INET, socket.SOCK_STREAM, 6, '', ('93.184.216.34', 0)) # historically example.com's address (a safe public IP) + ] + + +def create_mock_response(status_code: int, text: str = "", url: str = "https://example.com") -> Mock: + """Create a mock response object that supports streaming API.""" + response = Mock() + response.status_code = status_code + response.text = text + response.encoding = 'utf-8' + response.history = [] # Empty list for no redirects + response.url = url # Mock URL for redirect checks (must be string, not Mock) + + # Mock iter_content to yield text in chunks as bytes + text_bytes = text.encode('utf-8') + chunk_size = 8192 + chunks = [text_bytes[i:i+chunk_size] for i in range(0, len(text_bytes), chunk_size)] + if not chunks: + chunks = [b''] # Ensure at least one empty chunk + response.iter_content = Mock(return_value=iter(chunks)) + + # Mock close method + response.close = Mock() + + return response + + +class TestDiscoveryService: + """Test suite for DiscoveryService class.""" + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_discover_files_basic(self, mock_get, mock_session, mock_dns): + """Test main discovery method returns single best file.""" + service = DiscoveryService() + base_url = "https://example.com" + + # Mock robots.txt response (no sitemaps) + robots_response = create_mock_response(200, "User-agent: *\nDisallow: /admin/") + + # Mock file existence - llms-full.txt doesn't exist, but llms.txt does + def mock_get_side_effect(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif url.endswith('llms-full.txt'): + return create_mock_response(404) # Lower-priority llms-full.txt doesn't exist + elif url.endswith('llms.txt'): + return create_mock_response(200) # Highest-priority llms.txt exists + else: + return create_mock_response(404) + + mock_get.side_effect = mock_get_side_effect + mock_session.return_value.get.side_effect = mock_get_side_effect + + result = service.discover_files(base_url) + + # Should return single URL string (not dict, not list) + assert isinstance(result, str) + assert result == 'https://example.com/llms.txt' + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_discover_files_no_files_found(self, mock_get, mock_session, mock_dns): + """Test discovery when no files are found.""" + service = DiscoveryService() + base_url = "https://example.com" + + # Mock all HTTP requests to return 404 + mock_get.return_value = create_mock_response(404) + mock_session.return_value.get.return_value = create_mock_response(404) + + result = service.discover_files(base_url) + + # Should return None when no files found + assert result is None + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_discover_files_priority_order(self, mock_get, mock_session, mock_dns): + """Test that discovery follows the correct priority order.""" + service = 
DiscoveryService() + base_url = "https://example.com" + + # Mock robots.txt response (no sitemaps declared) + robots_response = create_mock_response(200, "User-agent: *\nDisallow: /admin/") + + # Mock file existence - both sitemap.xml and llms.txt exist, but llms.txt has higher priority + def mock_get_side_effect(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif url.endswith('llms.txt') or url.endswith('sitemap.xml'): + return create_mock_response(200) # Both exist + else: + return create_mock_response(404) + + mock_get.side_effect = mock_get_side_effect + mock_session.return_value.get.side_effect = mock_get_side_effect + + result = service.discover_files(base_url) + + # Should return llms.txt since it has higher priority than sitemap.xml + assert result == 'https://example.com/llms.txt' + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_discover_files_robots_sitemap_priority(self, mock_get, mock_session, mock_dns): + """Test that llms files have priority over robots.txt sitemap declarations.""" + service = DiscoveryService() + base_url = "https://example.com" + + # Mock robots.txt response WITH sitemap declaration + robots_response = create_mock_response(200, "User-agent: *\nSitemap: https://example.com/declared-sitemap.xml") + + # Mock other files also exist (both llms and sitemap files) + def mock_get_side_effect(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif 'llms' in url or 'sitemap' in url: + return create_mock_response(200) + else: + return create_mock_response(404) + + mock_get.side_effect = mock_get_side_effect + mock_session.return_value.get.side_effect = mock_get_side_effect + + result = service.discover_files(base_url) + + # Should return llms.txt (highest priority llms file) since llms files have priority over sitemaps + # even when sitemaps are declared in robots.txt + assert result == 'https://example.com/llms.txt' + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_discover_files_subdirectory_fallback(self, mock_get, mock_session, mock_dns): + """Test discovery falls back to subdirectories for llms files.""" + service = DiscoveryService() + base_url = "https://example.com" + + # Mock robots.txt response (no sitemaps declared) + robots_response = create_mock_response(200, "User-agent: *\nDisallow: /admin/") + + # Mock file existence - no root llms files, but static/llms.txt exists + def mock_get_side_effect(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif '/static/llms.txt' in url: + return create_mock_response(200) # Found in subdirectory + else: + return create_mock_response(404) + + mock_get.side_effect = mock_get_side_effect + mock_session.return_value.get.side_effect = mock_get_side_effect + + result = service.discover_files(base_url) + + # Should find the file in static subdirectory + assert result == 'https://example.com/static/llms.txt' + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_check_url_exists(self, mock_get, mock_session, mock_dns): + """Test URL existence checking.""" + service = DiscoveryService() + + # Test successful response + mock_get.return_value = create_mock_response(200) + mock_session.return_value.get.return_value = create_mock_response(200) + assert 
service._check_url_exists("https://example.com/exists") is True + + # Test 404 response + mock_get.return_value = create_mock_response(404) + mock_session.return_value.get.return_value = create_mock_response(404) + assert service._check_url_exists("https://example.com/not-found") is False + + # Test network error + mock_get.side_effect = Exception + mock_session.return_value.get.side_effect = Exception("Network error") + assert service._check_url_exists("https://example.com/error") is False + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_parse_robots_txt_with_sitemap(self, mock_get, mock_session, mock_dns): + """Test robots.txt parsing with sitemap directives.""" + service = DiscoveryService() + + # Mock successful robots.txt response + robots_text = """User-agent: * +Disallow: /admin/ +Sitemap: https://example.com/sitemap.xml +Sitemap: https://example.com/sitemap-news.xml""" + mock_get.return_value = create_mock_response(200, robots_text) + + result = service._parse_robots_txt("https://example.com") + + assert len(result) == 2 + assert "https://example.com/sitemap.xml" in result + assert "https://example.com/sitemap-news.xml" in result + mock_get.assert_called_once_with("https://example.com/robots.txt", timeout=30, stream=True, verify=True, headers={'User-Agent': 'Archon-Discovery/1.0 (SSRF-Protected)'}) + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_parse_robots_txt_no_sitemap(self, mock_get, mock_session, mock_dns): + """Test robots.txt parsing without sitemap directives.""" + service = DiscoveryService() + + # Mock robots.txt without sitemaps + robots_text = """User-agent: * +Disallow: /admin/ +Allow: /public/""" + mock_get.return_value = create_mock_response(200, robots_text) + + result = service._parse_robots_txt("https://example.com") + + assert len(result) == 0 + mock_get.assert_called_once_with("https://example.com/robots.txt", timeout=30, stream=True, verify=True, headers={'User-Agent': 'Archon-Discovery/1.0 (SSRF-Protected)'}) + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_parse_html_meta_tags(self, mock_get, mock_session, mock_dns): + """Test HTML meta tag parsing for sitemaps.""" + service = DiscoveryService() + + # Mock HTML with sitemap references + html_content = """ + + + + + + Content here + + """ + mock_get.return_value = create_mock_response(200, html_content) + + result = service._parse_html_meta_tags("https://example.com") + + # Should find sitemaps from both link and meta tags + assert len(result) >= 1 + assert any('sitemap' in url.lower() for url in result) + mock_get.assert_called_once_with("https://example.com", timeout=30, stream=True, verify=True, headers={'User-Agent': 'Archon-Discovery/1.0 (SSRF-Protected)'}) + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_discovery_priority_behavior(self, mock_get, mock_session, mock_dns): + """Test that discovery returns highest-priority file when multiple files exist.""" + service = DiscoveryService() + base_url = "https://example.com" + + # Mock robots.txt response (no sitemaps declared) + robots_response = create_mock_response(200, "User-agent: *\nDisallow: /admin/") + + # Scenario 1: All files exist - should return llms.txt (highest priority) + def 
mock_all_exist(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif any(file in url for file in ['llms.txt', 'llms-full.txt', 'sitemap.xml']): + return create_mock_response(200) + else: + return create_mock_response(404) + + mock_get.side_effect = mock_all_exist + mock_session.return_value.get.side_effect = mock_all_exist + result = service.discover_files(base_url) + assert result == 'https://example.com/llms.txt', "Should return llms.txt when all files exist (highest priority)" + + # Scenario 2: llms.txt missing, others exist - should return llms-full.txt + def mock_without_txt(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif url.endswith('llms.txt'): + return create_mock_response(404) + elif any(file in url for file in ['llms-full.txt', 'sitemap.xml']): + return create_mock_response(200) + else: + return create_mock_response(404) + + mock_get.side_effect = mock_without_txt + mock_session.return_value.get.side_effect = mock_without_txt + result = service.discover_files(base_url) + assert result == 'https://example.com/llms-full.txt', "Should return llms-full.txt when llms.txt is missing" + + # Scenario 3: Only sitemap files exist - should return sitemap.xml + def mock_only_sitemaps(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif any(file in url for file in ['llms.txt', 'llms-full.txt']): + return create_mock_response(404) + elif url.endswith('sitemap.xml'): + return create_mock_response(200) + else: + return create_mock_response(404) + + mock_get.side_effect = mock_only_sitemaps + mock_session.return_value.get.side_effect = mock_only_sitemaps + result = service.discover_files(base_url) + assert result == 'https://example.com/sitemap.xml', "Should return sitemap.xml when llms files are missing" + + # Scenario 4: llms files have priority over sitemap files + def mock_llms_and_sitemap(url, **kwargs): + if url.endswith('robots.txt'): + return robots_response + elif url.endswith('llms.txt') or url.endswith('sitemap.xml'): + return create_mock_response(200) + else: + return create_mock_response(404) + + mock_get.side_effect = mock_llms_and_sitemap + mock_session.return_value.get.side_effect = mock_llms_and_sitemap + result = service.discover_files(base_url) + assert result == 'https://example.com/llms.txt', "Should prefer llms.txt over sitemap.xml" + + @patch('socket.getaddrinfo', return_value=create_mock_dns_response()) + @patch('requests.Session') + @patch('requests.get') + def test_network_error_handling(self, mock_get, mock_session, mock_dns): + """Test error scenarios with network failures.""" + service = DiscoveryService() + + # Mock network error + mock_get.side_effect = Exception("Network error") + mock_session.return_value.get.side_effect = Exception("Network error") + + # Should not raise exception, but return None + result = service.discover_files("https://example.com") + assert result is None + + # Individual methods should also handle errors gracefully + result = service._parse_robots_txt("https://example.com") + assert result == [] + + result = service._parse_html_meta_tags("https://example.com") + assert result == [] diff --git a/python/tests/test_llms_txt_link_following.py b/python/tests/test_llms_txt_link_following.py new file mode 100644 index 0000000000..6cc43a5904 --- /dev/null +++ b/python/tests/test_llms_txt_link_following.py @@ -0,0 +1,217 @@ +"""Integration tests for llms.txt link following functionality.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch 
+from src.server.services.crawling.crawling_service import CrawlingService + + +class TestLlmsTxtLinkFollowing: + """Test suite for llms.txt link following feature.""" + + @pytest.fixture + def service(self): + """Create a CrawlingService instance for testing.""" + return CrawlingService(crawler=None, supabase_client=None) + + @pytest.fixture + def supabase_llms_content(self): + """Return the actual Supabase llms.txt content.""" + return """# Supabase Docs + +- [Supabase Guides](https://supabase.com/llms/guides.txt) +- [Supabase Reference (JavaScript)](https://supabase.com/llms/js.txt) +- [Supabase Reference (Dart)](https://supabase.com/llms/dart.txt) +- [Supabase Reference (Swift)](https://supabase.com/llms/swift.txt) +- [Supabase Reference (Kotlin)](https://supabase.com/llms/kotlin.txt) +- [Supabase Reference (Python)](https://supabase.com/llms/python.txt) +- [Supabase Reference (C#)](https://supabase.com/llms/csharp.txt) +- [Supabase CLI Reference](https://supabase.com/llms/cli.txt) +""" + + def test_extract_links_from_supabase_llms_txt(self, service, supabase_llms_content): + """Test that links are correctly extracted from Supabase llms.txt.""" + url = "https://supabase.com/docs/llms.txt" + + extracted_links = service.url_handler.extract_markdown_links_with_text( + supabase_llms_content, url + ) + + # Should extract 8 links + assert len(extracted_links) == 8 + + # Verify all extracted links + expected_links = [ + "https://supabase.com/llms/guides.txt", + "https://supabase.com/llms/js.txt", + "https://supabase.com/llms/dart.txt", + "https://supabase.com/llms/swift.txt", + "https://supabase.com/llms/kotlin.txt", + "https://supabase.com/llms/python.txt", + "https://supabase.com/llms/csharp.txt", + "https://supabase.com/llms/cli.txt", + ] + + extracted_urls = [link for link, _ in extracted_links] + assert extracted_urls == expected_links + + def test_all_links_are_llms_variants(self, service, supabase_llms_content): + """Test that all extracted links are recognized as llms.txt variants.""" + url = "https://supabase.com/docs/llms.txt" + + extracted_links = service.url_handler.extract_markdown_links_with_text( + supabase_llms_content, url + ) + + # All links should be recognized as llms variants + for link, _ in extracted_links: + is_llms = service.url_handler.is_llms_variant(link) + assert is_llms, f"Link {link} should be recognized as llms.txt variant" + + def test_all_links_are_same_domain(self, service, supabase_llms_content): + """Test that all extracted links are from the same domain.""" + url = "https://supabase.com/docs/llms.txt" + original_domain = "https://supabase.com" + + extracted_links = service.url_handler.extract_markdown_links_with_text( + supabase_llms_content, url + ) + + # All links should be from the same domain + for link, _ in extracted_links: + is_same = service._is_same_domain_or_subdomain(link, original_domain) + assert is_same, f"Link {link} should match domain {original_domain}" + + def test_filter_llms_links_from_supabase(self, service, supabase_llms_content): + """Test the complete filtering logic for Supabase llms.txt.""" + url = "https://supabase.com/docs/llms.txt" + original_domain = "https://supabase.com" + + # Extract all links + extracted_links = service.url_handler.extract_markdown_links_with_text( + supabase_llms_content, url + ) + + # Filter for llms.txt files on same domain (mimics actual code) + llms_links = [] + for link, text in extracted_links: + if service.url_handler.is_llms_variant(link): + if service._is_same_domain_or_subdomain(link, 
original_domain): + llms_links.append((link, text)) + + # Should have all 8 links + assert len(llms_links) == 8, f"Expected 8 llms links, got {len(llms_links)}" + + @pytest.mark.asyncio + async def test_llms_txt_link_following_integration(self, service, supabase_llms_content): + """Integration test for the complete llms.txt link following flow.""" + url = "https://supabase.com/docs/llms.txt" + + # Mock the crawl_batch_with_progress to verify it's called with correct URLs + mock_batch_results = [ + {'url': f'https://supabase.com/llms/{name}.txt', 'markdown': f'# {name}', 'title': f'{name}'} + for name in ['guides', 'js', 'dart', 'swift', 'kotlin', 'python', 'csharp', 'cli'] + ] + + service.crawl_batch_with_progress = AsyncMock(return_value=mock_batch_results) + service.crawl_markdown_file = AsyncMock(return_value=[{ + 'url': url, + 'markdown': supabase_llms_content, + 'title': 'Supabase Docs' + }]) + + # Create progress tracker mock + service.progress_tracker = MagicMock() + service.progress_tracker.update = AsyncMock() + + # Simulate the request that would come from orchestration + request = { + "is_discovery_target": True, + "original_domain": "https://supabase.com", + "max_concurrent": 5 + } + + # Call the actual crawl method + crawl_results, crawl_type = await service._crawl_by_url_type(url, request) + + # Verify batch crawl was called with the 8 llms.txt URLs + service.crawl_batch_with_progress.assert_called_once() + call_args = service.crawl_batch_with_progress.call_args + crawled_urls = call_args[0][0] # First positional argument + + assert len(crawled_urls) == 8, f"Should crawl 8 linked files, got {len(crawled_urls)}" + + expected_urls = [ + "https://supabase.com/llms/guides.txt", + "https://supabase.com/llms/js.txt", + "https://supabase.com/llms/dart.txt", + "https://supabase.com/llms/swift.txt", + "https://supabase.com/llms/kotlin.txt", + "https://supabase.com/llms/python.txt", + "https://supabase.com/llms/csharp.txt", + "https://supabase.com/llms/cli.txt", + ] + + assert set(crawled_urls) == set(expected_urls) + + # Verify total results include main file + linked pages + assert len(crawl_results) == 9, f"Should have 9 total pages (1 main + 8 linked), got {len(crawl_results)}" + + # Verify crawl type + assert crawl_type == "llms_txt_with_linked_pages" + + def test_external_llms_links_are_filtered(self, service): + """Test that external domain llms.txt links are filtered out.""" + content = """# Test llms.txt + +- [Internal Link](https://supabase.com/llms/internal.txt) +- [External Link](https://external.com/llms/external.txt) +- [Another Internal](https://docs.supabase.com/llms/docs.txt) +""" + url = "https://supabase.com/llms.txt" + original_domain = "https://supabase.com" + + extracted_links = service.url_handler.extract_markdown_links_with_text(content, url) + + # Filter for same-domain llms links + llms_links = [] + for link, text in extracted_links: + if service.url_handler.is_llms_variant(link): + if service._is_same_domain_or_subdomain(link, original_domain): + llms_links.append((link, text)) + + # Should only have 2 links (internal and subdomain), external filtered out + assert len(llms_links) == 2 + + urls = [link for link, _ in llms_links] + assert "https://supabase.com/llms/internal.txt" in urls + assert "https://docs.supabase.com/llms/docs.txt" in urls + assert "https://external.com/llms/external.txt" not in urls + + def test_non_llms_links_are_filtered(self, service): + """Test that non-llms.txt links are filtered out.""" + content = """# Test llms.txt + +- [LLMs 
Link](https://supabase.com/llms/guide.txt) +- [Regular Doc](https://supabase.com/docs/guide) +- [PDF File](https://supabase.com/docs/guide.pdf) +- [Another LLMs](https://supabase.com/llms/api.txt) +""" + url = "https://supabase.com/llms.txt" + original_domain = "https://supabase.com" + + extracted_links = service.url_handler.extract_markdown_links_with_text(content, url) + + # Filter for llms links only + llms_links = [] + for link, text in extracted_links: + if service.url_handler.is_llms_variant(link): + if service._is_same_domain_or_subdomain(link, original_domain): + llms_links.append((link, text)) + + # Should only have 2 llms.txt links + assert len(llms_links) == 2 + + urls = [link for link, _ in llms_links] + assert "https://supabase.com/llms/guide.txt" in urls + assert "https://supabase.com/llms/api.txt" in urls + assert "https://supabase.com/docs/guide" not in urls + assert "https://supabase.com/docs/guide.pdf" not in urls diff --git a/python/tests/test_url_handler.py b/python/tests/test_url_handler.py index 1310bd8741..e268bd500b 100644 --- a/python/tests/test_url_handler.py +++ b/python/tests/test_url_handler.py @@ -122,4 +122,120 @@ def test_transform_github_url(self): # Should not transform non-GitHub URLs other = "https://example.com/file" - assert handler.transform_github_url(other) == other \ No newline at end of file + assert handler.transform_github_url(other) == other + + def test_is_robots_txt(self): + """Test robots.txt detection.""" + handler = URLHandler() + + # Standard robots.txt URLs + assert handler.is_robots_txt("https://example.com/robots.txt") is True + assert handler.is_robots_txt("http://example.com/robots.txt") is True + assert handler.is_robots_txt("https://sub.example.com/robots.txt") is True + + # Case sensitivity + assert handler.is_robots_txt("https://example.com/ROBOTS.TXT") is True + assert handler.is_robots_txt("https://example.com/Robots.Txt") is True + + # With query parameters (should still be detected) + assert handler.is_robots_txt("https://example.com/robots.txt?v=1") is True + assert handler.is_robots_txt("https://example.com/robots.txt#section") is True + + # Not robots.txt files + assert handler.is_robots_txt("https://example.com/robots") is False + assert handler.is_robots_txt("https://example.com/robots.html") is False + assert handler.is_robots_txt("https://example.com/some-robots.txt") is False + assert handler.is_robots_txt("https://example.com/path/robots.txt") is False + assert handler.is_robots_txt("https://example.com/") is False + + # Edge case: malformed URL should not crash + assert handler.is_robots_txt("not-a-url") is False + + def test_is_llms_variant(self): + """Test llms file variant detection.""" + handler = URLHandler() + + # Standard llms.txt spec variants (only txt files) + assert handler.is_llms_variant("https://example.com/llms.txt") is True + assert handler.is_llms_variant("https://example.com/llms-full.txt") is True + + # Case sensitivity + assert handler.is_llms_variant("https://example.com/LLMS.TXT") is True + assert handler.is_llms_variant("https://example.com/LLMS-FULL.TXT") is True + + # With paths (should still detect) + assert handler.is_llms_variant("https://example.com/docs/llms.txt") is True + assert handler.is_llms_variant("https://example.com/public/llms-full.txt") is True + + # With query parameters + assert handler.is_llms_variant("https://example.com/llms.txt?version=1") is True + assert handler.is_llms_variant("https://example.com/llms-full.txt#section") is True + + # Not llms files + assert 
handler.is_llms_variant("https://example.com/llms") is False + assert handler.is_llms_variant("https://example.com/llms.html") is False + assert handler.is_llms_variant("https://example.com/my-llms.txt") is False + assert handler.is_llms_variant("https://example.com/llms-guide.txt") is False + assert handler.is_llms_variant("https://example.com/readme.txt") is False + + # Edge case: malformed URL should not crash + assert handler.is_llms_variant("not-a-url") is False + + def test_is_well_known_file(self): + """Test .well-known file detection.""" + handler = URLHandler() + + # Standard .well-known files + assert handler.is_well_known_file("https://example.com/.well-known/ai.txt") is True + assert handler.is_well_known_file("https://example.com/.well-known/security.txt") is True + assert handler.is_well_known_file("https://example.com/.well-known/change-password") is True + + # Case sensitivity - RFC 8615 requires lowercase .well-known + assert handler.is_well_known_file("https://example.com/.WELL-KNOWN/ai.txt") is False + assert handler.is_well_known_file("https://example.com/.Well-Known/ai.txt") is False + + # With query parameters + assert handler.is_well_known_file("https://example.com/.well-known/ai.txt?v=1") is True + assert handler.is_well_known_file("https://example.com/.well-known/ai.txt#top") is True + + # Not .well-known files + assert handler.is_well_known_file("https://example.com/well-known/ai.txt") is False + assert handler.is_well_known_file("https://example.com/.wellknown/ai.txt") is False + assert handler.is_well_known_file("https://example.com/docs/.well-known/ai.txt") is False + assert handler.is_well_known_file("https://example.com/ai.txt") is False + assert handler.is_well_known_file("https://example.com/") is False + + # Edge case: malformed URL should not crash + assert handler.is_well_known_file("not-a-url") is False + + def test_get_base_url(self): + """Test base URL extraction.""" + handler = URLHandler() + + # Standard URLs + assert handler.get_base_url("https://example.com") == "https://example.com" + assert handler.get_base_url("https://example.com/") == "https://example.com" + assert handler.get_base_url("https://example.com/path/to/page") == "https://example.com" + assert handler.get_base_url("https://example.com/path/to/page?query=1") == "https://example.com" + assert handler.get_base_url("https://example.com/path/to/page#fragment") == "https://example.com" + + # HTTP vs HTTPS + assert handler.get_base_url("http://example.com/path") == "http://example.com" + assert handler.get_base_url("https://example.com/path") == "https://example.com" + + # Subdomains and ports + assert handler.get_base_url("https://api.example.com/v1/users") == "https://api.example.com" + assert handler.get_base_url("https://example.com:8080/api") == "https://example.com:8080" + assert handler.get_base_url("http://localhost:3000/dev") == "http://localhost:3000" + + # Complex cases + assert handler.get_base_url("https://user:pass@example.com/path") == "https://user:pass@example.com" + + # Edge cases - malformed URLs should return original + assert handler.get_base_url("not-a-url") == "not-a-url" + assert handler.get_base_url("") == "" + assert handler.get_base_url("ftp://example.com/file") == "ftp://example.com" + + # Missing scheme or netloc + assert handler.get_base_url("//example.com/path") == "//example.com/path" # Should return original + assert handler.get_base_url("/path/to/resource") == "/path/to/resource" # Should return original \ No newline at end of file diff --git 
a/python/uv.lock b/python/uv.lock index 274564d2d0..f8f82b0185 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -247,6 +247,7 @@ server = [ { name = "python-multipart" }, { name = "slowapi" }, { name = "supabase" }, + { name = "tldextract" }, { name = "uvicorn" }, { name = "watchfiles" }, ] @@ -342,6 +343,7 @@ server = [ { name = "python-multipart", specifier = ">=0.0.20" }, { name = "slowapi", specifier = ">=0.1.9" }, { name = "supabase", specifier = "==2.15.1" }, + { name = "tldextract", specifier = ">=5.0.0" }, { name = "uvicorn", specifier = ">=0.24.0" }, { name = "watchfiles", specifier = ">=0.18" }, ] @@ -2601,6 +2603,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] +[[package]] +name = "requests-file" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fe/5e/2aca791207e542a16a8cc91fd0e19f5c26f4dff030ee3062deb5606f84ae/requests_file-3.0.0.tar.gz", hash = "sha256:68789589cfde7098e8933fe3e69bbd864f7f0c22f118937b424d94d0e1b7760f", size = 6897 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/85/689c218feb21a66919bd667969d4ed60a64db67f6ea5ceb00c9795ae19b0/requests_file-3.0.0-py2.py3-none-any.whl", hash = "sha256:aca222ec94a19310be2a0ed6bdcdebb09058b0f6c3e984af56361c8fca59653c", size = 4486 }, +] + [[package]] name = "rich" version = "14.0.0" @@ -3086,6 +3100,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669 }, ] +[[package]] +name = "tldextract" +version = "5.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "idna" }, + { name = "requests" }, + { name = "requests-file" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/78/182641ea38e3cfd56e9c7b3c0d48a53d432eea755003aa544af96403d4ac/tldextract-5.3.0.tar.gz", hash = "sha256:b3d2b70a1594a0ecfa6967d57251527d58e00bb5a91a74387baa0d87a0678609", size = 128502 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/7c/ea488ef48f2f544566947ced88541bc45fae9e0e422b2edbf165ee07da99/tldextract-5.3.0-py3-none-any.whl", hash = "sha256:f70f31d10b55c83993f55e91ecb7c5d84532a8972f22ec578ecfbe5ea2292db2", size = 107384 }, +] + [[package]] name = "tokenizers" version = "0.21.1"