From 2004073b8ec34a705933b345191faf560c9d22a7 Mon Sep 17 00:00:00 2001 From: charlesmeng18 Date: Thu, 30 Oct 2025 17:13:33 -0700 Subject: [PATCH 01/26] added tools for searching/booking/add-ons for flights --- scripts/generate_flights.py | 473 ++++++++++++++++++ .../backend/services/airline_chat.py | 21 +- src/airline_agent/constants.py | 7 + src/airline_agent/tools/booking.py | 342 +++++++++++++ src/airline_agent/types/booking.py | 93 ++++ 5 files changed, 933 insertions(+), 3 deletions(-) create mode 100644 scripts/generate_flights.py create mode 100644 src/airline_agent/tools/booking.py create mode 100644 src/airline_agent/types/booking.py diff --git a/scripts/generate_flights.py b/scripts/generate_flights.py new file mode 100644 index 0000000..6277d01 --- /dev/null +++ b/scripts/generate_flights.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 +""" +Script to generate Frontier Airlines (F9) flight data for SF Bay Area to New York routes. +Includes direct flights and connecting flights with layovers through hub airports. +Comprehensive coverage for Halloween week 2025 (Oct 31 - Nov 7). +""" + +import json +import random +from datetime import datetime, timedelta +from pathlib import Path + +# San Francisco Bay Area airports +SF_AIRPORTS = ["SFO", "SJC", "OAK"] + +# Common hub airports for layovers between SF and NYC +HUB_AIRPORTS = ["DEN", "ORD", "ATL", "DFW", "LAS", "PHX", "SEA", "IAH", "MSP", "DTW"] + +# New York airports +NYC_AIRPORTS = ["JFK", "EWR", "LGA"] + +# Frontier Airlines only +CARRIER_CODE = "F9" +CARRIER_NAME = "Frontier" + +# Cabin configurations and base prices (Frontier Airlines style) +CABIN_CONFIGS = { + "economy": {"base_price": 120, "price_range": (80, 180)}, + "premium_economy": {"base_price": 200, "price_range": (150, 280)}, + "business": {"base_price": 350, "price_range": (280, 500)}, +} + +# Flight duration estimates (in hours) +# Base durations for SFO routes +BASE_DURATIONS = { + ("SFO", "JFK"): 5.5, + ("SFO", "EWR"): 5.3, + ("SFO", "LGA"): 5.4, + ("SFO", "DEN"): 2.5, + ("SFO", "ORD"): 4.0, + ("SFO", "ATL"): 4.5, + ("SFO", "DFW"): 3.5, + ("SFO", "LAS"): 1.5, + ("SFO", "PHX"): 1.8, + ("SFO", "SEA"): 2.0, + ("SFO", "IAH"): 3.8, + ("SFO", "MSP"): 3.5, + ("SFO", "DTW"): 4.2, +} + +# SJC and OAK are similar to SFO, with slight variations +FLIGHT_DURATIONS = dict(BASE_DURATIONS) +# SJC routes (slightly shorter than SFO) +for (orig, dest), duration in BASE_DURATIONS.items(): + if orig == "SFO": + FLIGHT_DURATIONS[("SJC", dest)] = duration - 0.1 if duration > 2 else duration + if dest == "SFO": + FLIGHT_DURATIONS[(orig, "SJC")] = duration - 0.1 if duration > 2 else duration +# OAK routes (slightly shorter than SFO) +for (orig, dest), duration in BASE_DURATIONS.items(): + if orig == "SFO": + FLIGHT_DURATIONS[("OAK", dest)] = duration - 0.1 if duration > 2 else duration + if dest == "SFO": + FLIGHT_DURATIONS[(orig, "OAK")] = duration - 0.1 if duration > 2 else duration + +# Hub to NYC routes +FLIGHT_DURATIONS.update({ + ("DEN", "JFK"): 3.5, + ("DEN", "EWR"): 3.3, + ("DEN", "LGA"): 3.4, + ("ORD", "JFK"): 2.0, + ("ORD", "EWR"): 2.0, + ("ORD", "LGA"): 2.0, + ("ATL", "JFK"): 2.5, + ("ATL", "EWR"): 2.3, + ("ATL", "LGA"): 2.4, + ("DFW", "JFK"): 3.5, + ("DFW", "EWR"): 3.3, + ("DFW", "LGA"): 3.4, + ("LAS", "JFK"): 5.0, + ("LAS", "EWR"): 4.8, + ("LAS", "LGA"): 4.9, + ("PHX", "JFK"): 4.5, + ("PHX", "EWR"): 4.3, + ("PHX", "LGA"): 4.4, + ("SEA", "JFK"): 5.5, + ("SEA", "EWR"): 5.3, + ("SEA", "LGA"): 5.4, + ("IAH", "JFK"): 3.0, + ("IAH", "EWR"): 2.8, + ("IAH", "LGA"): 2.9, + ("MSP", "JFK"): 2.8, + ("MSP", "EWR"): 2.6, + ("MSP", "LGA"): 2.7, + ("DTW", "JFK"): 1.8, + ("DTW", "EWR"): 1.6, + ("DTW", "LGA"): 1.7, +}) + +# NYC to Hub routes (reverse of hub to NYC) +for (hub, nyc), duration in list(FLIGHT_DURATIONS.items()): + if hub in HUB_AIRPORTS and nyc in NYC_AIRPORTS: + FLIGHT_DURATIONS[(nyc, hub)] = duration + +# NYC to SF routes (reverse of SF to NYC) +for (sf, nyc), duration in list(FLIGHT_DURATIONS.items()): + if sf in SF_AIRPORTS and nyc in NYC_AIRPORTS: + FLIGHT_DURATIONS[(nyc, sf)] = duration + +# SF to Hub routes +for sf in SF_AIRPORTS: + for hub in HUB_AIRPORTS: + if (sf, hub) not in FLIGHT_DURATIONS: + # Use SFO duration as base + base_duration = FLIGHT_DURATIONS.get(("SFO", hub), 3.0) + if sf == "SJC" or sf == "OAK": + FLIGHT_DURATIONS[(sf, hub)] = base_duration - 0.1 if base_duration > 2 else base_duration + else: + FLIGHT_DURATIONS[(sf, hub)] = base_duration + +# Hub to SF routes (reverse) +for (sf, hub), duration in list(FLIGHT_DURATIONS.items()): + if sf in SF_AIRPORTS and hub in HUB_AIRPORTS: + FLIGHT_DURATIONS[(hub, sf)] = duration + +# Timezone offsets (PST/PDT for SF airports, EST/EDT for NYC) +SF_TZ_OFFSET = -8 # PST +NYC_TZ_OFFSET = -5 # EST + +# Hub airport timezones +HUB_TZ_OFFSETS = { + "DEN": -7, # MST + "ORD": -6, # CST + "ATL": -5, # EST + "DFW": -6, # CST + "LAS": -8, # PST + "PHX": -7, # MST + "SEA": -8, # PST + "IAH": -6, # CST + "MSP": -6, # CST + "DTW": -5, # EST +} + + +def get_flight_duration(origin: str, destination: str) -> float: + """Get flight duration in hours for a route.""" + route = (origin, destination) + return FLIGHT_DURATIONS.get(route, 3.0) # Default 3 hours if not found + + +def get_timezone_offset(airport: str) -> int: + """Get timezone offset for an airport.""" + if airport in SF_AIRPORTS: + return SF_TZ_OFFSET + if airport in NYC_AIRPORTS: + return NYC_TZ_OFFSET + return HUB_TZ_OFFSETS.get(airport, -5) + + +def generate_fares() -> list[dict]: + """Generate random fares for a flight with different fare types.""" + fares = [] + + # Always include economy with multiple fare types + economy_config = CABIN_CONFIGS["economy"] + base_price = random.uniform(*economy_config["price_range"]) + + # Basic fare: no bags included + fares.append({ + "cabin": "economy", + "fare_type": "basic", + "price_total": round(base_price, 2), + "currency": "USD", + "seats_available": random.randint(3, 12), + "included_carry_on": False, + "included_checked_bag": False, + }) + + # Standard fare: includes carry-on (+$15-25 more than basic) + standard_price = base_price + random.uniform(15, 25) + fares.append({ + "cabin": "economy", + "fare_type": "standard", + "price_total": round(standard_price, 2), + "currency": "USD", + "seats_available": random.randint(2, 10), + "included_carry_on": True, + "included_checked_bag": False, + }) + + # Flexible fare: includes checked bag (+$30-45 more than basic) + flexible_price = base_price + random.uniform(30, 45) + fares.append({ + "cabin": "economy", + "fare_type": "flexible", + "price_total": round(flexible_price, 2), + "currency": "USD", + "seats_available": random.randint(1, 8), + "included_carry_on": True, + "included_checked_bag": True, + }) + + # Randomly add premium economy (30% chance - Frontier has limited premium options) + if random.random() < 0.3: + premium_config = CABIN_CONFIGS["premium_economy"] + premium_base = random.uniform(*premium_config["price_range"]) + + # Premium economy basic + fares.append({ + "cabin": "premium_economy", + "fare_type": "basic", + "price_total": round(premium_base, 2), + "currency": "USD", + "seats_available": random.randint(2, 6), + "included_carry_on": True, # Premium always includes carry-on + "included_checked_bag": False, + }) + + # Premium economy flexible (with checked bag) + fares.append({ + "cabin": "premium_economy", + "fare_type": "flexible", + "price_total": round(premium_base + random.uniform(20, 35), 2), + "currency": "USD", + "seats_available": random.randint(1, 4), + "included_carry_on": True, + "included_checked_bag": True, + }) + + # Randomly add business (20% chance - Frontier has limited business class) + if random.random() < 0.2: + business_config = CABIN_CONFIGS["business"] + business_base = random.uniform(*business_config["price_range"]) + + # Business class always includes everything + fares.append({ + "cabin": "business", + "fare_type": "flexible", # Business is always flexible + "price_total": round(business_base, 2), + "currency": "USD", + "seats_available": random.randint(1, 4), + "included_carry_on": True, + "included_checked_bag": True, + }) + + # No first class for Frontier Airlines + + return fares + + +def generate_add_ons() -> list[dict]: + """Generate available add-on services for a flight.""" + add_ons = [ + { + "service_type": "checked_bag", + "price": round(random.uniform(30, 40), 2), + "currency": "USD", + "description": "One checked bag (up to 50 lbs, 62 linear inches)", + }, + { + "service_type": "carry_on", + "price": round(random.uniform(20, 30), 2), + "currency": "USD", + "description": "One carry-on bag (personal item included)", + }, + { + "service_type": "seat_selection", + "price": round(random.uniform(10, 25), 2), + "currency": "USD", + "description": "Select your seat in advance", + }, + { + "service_type": "priority_boarding", + "price": round(random.uniform(8, 15), 2), + "currency": "USD", + "description": "Priority boarding (Zone 2)", + }, + { + "service_type": "travel_insurance", + "price": round(random.uniform(15, 30), 2), + "currency": "USD", + "description": "Trip protection insurance", + }, + ] + return add_ons + + +def generate_flight_id(origin: str, destination: str, departure: datetime, carrier: str) -> str: + """Generate a unique flight ID.""" + date_str = departure.strftime("%Y-%m-%dT%H:%M") + return f"{carrier}-{origin}-{destination}-{date_str}" + + +def generate_direct_flights(start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None) -> list[dict]: + """Generate direct flights from origin airports to destination airports.""" + if origin_airports is None: + origin_airports = SF_AIRPORTS + if dest_airports is None: + dest_airports = NYC_AIRPORTS + + flights = [] + + for day in range(num_days): + date = start_date + timedelta(days=day) + + # Generate comprehensive flights - multiple per origin-destination pair + for origin in origin_airports: + for destination in dest_airports: + # Generate 3-6 flights per origin-destination pair per day + num_flights = random.randint(3, 6) + + for flight_num in range(num_flights): + # Random departure time between 6 AM and 10 PM + hour = random.randint(6, 22) + minute = random.choice([0, 15, 30, 45]) + + carrier_code = CARRIER_CODE + + departure_time = date.replace(hour=hour, minute=minute, second=0, microsecond=0) + + # Calculate arrival time + duration = get_flight_duration(origin, destination) + arrival_time = departure_time + timedelta(hours=duration) + + # Adjust for timezone + departure_offset = get_timezone_offset(origin) + arrival_offset = get_timezone_offset(destination) + + departure_str = departure_time.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset:+03d}:00") + arrival_str = arrival_time.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset:+03d}:00") + + flight = { + "id": generate_flight_id(origin, destination, departure_time, carrier_code), + "origin": origin, + "destination": destination, + "departure": departure_str, + "arrival": arrival_str, + "flight_number": f"{carrier_code} {random.randint(100, 999)}", + "carrier": carrier_code, + "fares": generate_fares(), + "add_ons": generate_add_ons(), + } + + flights.append(flight) + + return flights + + +def generate_connecting_flights(start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None) -> list[dict]: + """Generate connecting flights from origin airports to destination airports via hub airports.""" + if origin_airports is None: + origin_airports = SF_AIRPORTS + if dest_airports is None: + dest_airports = NYC_AIRPORTS + + flights = [] + + for day in range(num_days): + date = start_date + timedelta(days=day) + + # Generate comprehensive connecting routes - all combinations + for origin in origin_airports: + for destination in dest_airports: + # Generate multiple connecting routes through different hubs + # Use all hubs to create many transfer options + for hub in HUB_AIRPORTS: + # Generate 1-3 connecting flights per hub per origin-destination pair per day + num_routes = random.randint(1, 3) + + for _ in range(num_routes): + carrier_code = CARRIER_CODE + + # First leg: Origin -> Hub + hour1 = random.randint(6, 18) + minute1 = random.choice([0, 15, 30, 45]) + departure_time_leg1 = date.replace(hour=hour1, minute=minute1, second=0, microsecond=0) + + duration1 = get_flight_duration(origin, hub) + arrival_time_leg1 = departure_time_leg1 + timedelta(hours=duration1) + + # Layover: 45 minutes to 3 hours + layover_hours = random.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) + departure_time_leg2 = arrival_time_leg1 + timedelta(hours=layover_hours) + + # Second leg: Hub -> Destination + duration2 = get_flight_duration(hub, destination) + arrival_time_leg2 = departure_time_leg2 + timedelta(hours=duration2) + + # First leg + departure_offset_leg1 = get_timezone_offset(origin) + arrival_offset_leg1 = get_timezone_offset(hub) + + flight1 = { + "id": generate_flight_id(origin, hub, departure_time_leg1, carrier_code), + "origin": origin, + "destination": hub, + "departure": departure_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset_leg1:+03d}:00"), + "arrival": arrival_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg1:+03d}:00"), + "flight_number": f"{carrier_code} {random.randint(100, 999)}", + "carrier": carrier_code, + "fares": generate_fares(), + "add_ons": generate_add_ons(), + } + + # Second leg + departure_offset_leg2 = get_timezone_offset(hub) + arrival_offset_leg2 = get_timezone_offset(destination) + + flight2 = { + "id": generate_flight_id(hub, destination, departure_time_leg2, carrier_code), + "origin": hub, + "destination": destination, + "departure": departure_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset_leg2:+03d}:00"), + "arrival": arrival_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg2:+03d}:00"), + "flight_number": f"{carrier_code} {random.randint(100, 999)}", + "carrier": carrier_code, + "fares": generate_fares(), + "add_ons": generate_add_ons(), + } + + flights.extend([flight1, flight2]) + + return flights + + +def main(): + """Main function to generate and save flight data.""" + # Set random seed for reproducibility + random.seed(42) + + # Start date: Halloween 2025 (October 31) + 1 week + start_date = datetime(2025, 10, 31) + num_days = 8 # Oct 31 - Nov 7 + + print("Generating comprehensive flight data for Halloween week 2025 (Oct 31 - Nov 7)...") + + # Generate SF -> NYC flights (direct only) + print("Generating direct flights (SF -> NYC)...") + direct_flights_sf_to_nyc = generate_direct_flights(start_date, num_days=num_days, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS) + print(f"Generated {len(direct_flights_sf_to_nyc)} direct flights from SF to NYC") + + # Generate NYC -> SF flights (direct only) + print("Generating direct flights (NYC -> SF)...") + direct_flights_nyc_to_sf = generate_direct_flights(start_date, num_days=num_days, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS) + print(f"Generated {len(direct_flights_nyc_to_sf)} direct flights from NYC to SF") + + # Combine all flights (direct only, no transfers) + all_flights = direct_flights_sf_to_nyc + direct_flights_nyc_to_sf + + # Get the project root (two levels up from scripts/) + project_root = Path(__file__).parent.parent + flights_file = project_root / "data" / "flights.json" + + # Sort by departure time + all_flights.sort(key=lambda x: x["departure"]) + + output_data = { + "flights": all_flights + } + + with open(flights_file, "w") as f: + json.dump(output_data, f, indent=2) + + print(f"\n✓ Successfully saved {len(all_flights)} total flights to {flights_file}") + print(f" - Direct flights SF->NYC: {len(direct_flights_sf_to_nyc)}") + print(f" - Direct flights NYC->SF: {len(direct_flights_nyc_to_sf)}") + print(f" - All flights are DIRECT flights for Halloween week 2025 (Oct 31 - Nov 7)") + print(f" - No connecting/transfer flights included") + + +if __name__ == "__main__": + main() + diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 352d899..8be7398 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -48,6 +48,7 @@ ) from airline_agent.constants import AGENT_INSTRUCTIONS, AGENT_MODEL from airline_agent.tools.knowledge_base import KnowledgeBase +from airline_agent.tools.booking import BookingTools load_dotenv() @@ -55,13 +56,23 @@ logger.setLevel(logging.INFO) -def create_agent(kb: KnowledgeBase) -> Agent: +def create_agent(kb: KnowledgeBase, booking: BookingTools) -> Agent: """Create the airline support agent.""" model = OpenAIChatModel(model_name=AGENT_MODEL, settings=ModelSettings(temperature=0.0)) return Agent( model=model, instructions=AGENT_INSTRUCTIONS, - tools=[kb.get_article, kb.search, kb.list_directory], + tools=[ + kb.get_article, + kb.search, + kb.list_directory, + booking.search_flights, + booking.get_fare_details, + booking.book_flights, + booking.get_booking, + booking.get_my_bookings, + booking.add_service_to_booking, + ], ) @@ -78,8 +89,12 @@ def get_cleanlab_project() -> Project: kb_path=str(pathlib.Path(__file__).parents[4] / "data/kb.json"), vector_index_path=str(pathlib.Path(__file__).parents[4] / "data/vector-db"), ) +booking = BookingTools( + flights_path=str(pathlib.Path(__file__).parents[4] / "data/flights.json"), + reservations_path=str(pathlib.Path(__file__).parents[4] / "data/reservations.json"), +) project = get_cleanlab_project() -agent = create_agent(kb) +agent = create_agent(kb, booking) thread_to_messages: dict[str, list[ModelMessage]] = {} cleanlab_enabled_by_thread: dict[str, bool] = {} diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index 21be9ee..60acfdd 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -18,17 +18,24 @@ - search — find candidate articles by query (keep top-k small, ≤5), returns title/snippet/path. - get_article — get the full article by its path. - list_directory — list directory structure to make more informed searches. +- search_flights — search available flights by origin airport code, destination airport code, and departure date (YYYY-MM-DD format). Always ask for the departure date if the user doesn't provide it. Common city names like "NYC" are automatically mapped to airport codes. +- book_flights — book one or more flights for the current user. Requires list of flight IDs and cabin class (defaults to economy). Returns booking confirmation with booking ID and total price. +- get_booking — retrieve booking details by booking ID. +- get_my_bookings — retrieve all confirmed bookings for the current user. ## Tool Use Guidelines: - Keep it tight: aim for 1-2 calls per turn (hard cap 4). - Answer only from retrieved content. - If a missing detail blocks tool use, ask one short clarifying question. If not blocking, proceed and state your assumption. - Don't dump raw tool output—summarize clearly. +- When booking multiple flights (outbound and return), include all flight IDs in a single book_flights call. ## Response Guidelines: - Answer questions based on information you look up in the knowledge base, not based on your own knowledge. - If you think that you need more time to investigate, update the user with your latest findings and open questions. You can proceed if the user confirms. - Discuss any airline-related topics with the user. +- When a booking is successfully created, provide the booking ID and confirmation details clearly. +- If you book flights, provide the booking ID and summarize the flights booked and total price. - If the user asks about anything unrelated to the airline, politely inform them that you can only assist with airline-related inquiries. """.strip() .replace("\n", " ") diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py new file mode 100644 index 0000000..cbc92d5 --- /dev/null +++ b/src/airline_agent/tools/booking.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import json +import uuid +from datetime import date, datetime +from pathlib import Path + +from airline_agent.types.booking import ( + Booking, + Flight, + FlightBooking, + BookingStatus, + ServiceAddOn, + ServiceAddOnOption, + FareType, +) + + +class BookingTools: + def __init__(self, flights_path: str, reservations_path: str | None = None): + with open(flights_path) as f: + raw = json.load(f) + self._flights: dict[str, Flight] = {x["id"]: Flight(**x) for x in raw["flights"]} + + # Initialize reservations storage + self._reservations_path = reservations_path + self._reservations: dict[str, Booking] = {} + if reservations_path: + self._load_reservations() + + def _load_reservations(self) -> None: + """Load reservations from JSON file.""" + if not self._reservations_path: + return + try: + reservations_file = Path(self._reservations_path) + if reservations_file.exists(): + with open(reservations_file) as f: + data = json.load(f) + self._reservations = { + bid: Booking(**booking_data) + for bid, booking_data in data.items() + } + else: + # Create empty file if it doesn't exist + reservations_file.parent.mkdir(parents=True, exist_ok=True) + self._save_reservations() + except (FileNotFoundError, json.JSONDecodeError): + self._reservations = {} + + def _save_reservations(self) -> None: + """Save reservations to JSON file.""" + if not self._reservations_path: + return + data = { + bid: booking.model_dump(mode="json") + for bid, booking in self._reservations.items() + } + with open(self._reservations_path, "w") as f: + json.dump(data, f, indent=2, default=str) + + def search_flights(self, origin: str, destination: str, departure_date: str) -> list[Flight]: + """ + Search available flights by route and date. + + Args: + origin: IATA airport code (e.g., "SFO", "JFK") + destination: IATA airport code (e.g., "SFO", "JFK") + departure_date: Date in YYYY-MM-DD format + + Returns: + List of available flights matching the route and date + """ + try: + dep = date.fromisoformat(departure_date) + except Exception as e: # noqa: BLE001 + raise ValueError(f"Invalid departure_date: {departure_date}") from e + + return [ + fl + for fl in self._flights.values() + if fl.origin == origin + and fl.destination == destination + and fl.departure.date().isoformat() == dep.isoformat() + ] + + def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") -> dict: + """ + Get detailed fare information including what's included and available add-ons. + + Args: + flight_id: The flight ID + cabin: Cabin class (economy, premium_economy, business, first) + fare_type: Fare type (basic, standard, flexible) + + Returns: + Dictionary with fare details including included services and available add-ons + """ + if flight_id not in self._flights: + raise ValueError(f"Flight not found: {flight_id}") + + flight = self._flights[flight_id] + + # Find the fare for the requested cabin and fare type + fare = next( + (f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), + None + ) + if not fare: + available_fares = [(f.cabin, f.fare_type) for f in flight.fares] + raise ValueError( + f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " + f"Available fares: {available_fares}" + ) + + included_services = [] + if fare.included_carry_on: + included_services.append("carry_on") + if fare.included_checked_bag: + included_services.append("checked_bag") + + return { + "flight_id": flight_id, + "cabin": cabin, + "fare_type": fare_type, + "price": fare.price_total, + "currency": fare.currency, + "seats_available": fare.seats_available, + "included_services": included_services, + "available_add_ons": [ + { + "service_type": addon.service_type, + "price": addon.price, + "currency": addon.currency, + "description": addon.description, + } + for addon in flight.add_ons + ], + } + + def book_flights( + self, + flight_ids: list[str], + cabin: str = "economy", + fare_type: str = "basic", + ) -> Booking: + """ + Book one or more flights for the current user. + + Args: + flight_ids: List of flight IDs to book + cabin: Cabin class (economy, premium_economy, business, first) + fare_type: Fare type (basic, standard, flexible). Defaults to "basic". + + Returns: + The created booking with booking ID and total price + """ + if not flight_ids: + raise ValueError("At least one flight ID must be provided") + + now = datetime.now() + booking_id = f"BK-{uuid.uuid4().hex[:8].upper()}" + + flight_bookings: list[FlightBooking] = [] + currency = "USD" + + for flight_id in flight_ids: + if flight_id not in self._flights: + raise ValueError(f"Flight not found: {flight_id}") + + flight = self._flights[flight_id] + + # Find the fare for the requested cabin and fare type + fare = next( + (f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), + None + ) + if not fare: + available_fares = [(f.cabin, f.fare_type) for f in flight.fares] + raise ValueError( + f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " + f"Available fares: {available_fares}" + ) + + if fare.seats_available <= 0: + raise ValueError( + f"No seats available for fare '{fare_type}' in {cabin} cabin for flight {flight_id}" + ) + + # Determine included services + included_services = [] + if fare.included_carry_on: + included_services.append("carry_on") + if fare.included_checked_bag: + included_services.append("checked_bag") + + flight_bookings.append( + FlightBooking( + flight_id=flight_id, + cabin=cabin, + fare_type=fare_type, + base_price=fare.price_total, + currency=fare.currency, + included_services=included_services, + add_ons=[], + ) + ) + currency = fare.currency # Use currency from last flight + + booking = Booking( + booking_id=booking_id, + flights=flight_bookings, + currency=currency, + status=BookingStatus( + status="confirmed", + created_at=now, + updated_at=now, + ), + ) + + self._reservations[booking_id] = booking + self._save_reservations() + + return booking + + def get_booking(self, booking_id: str) -> Booking: + """ + Retrieve a booking by its booking ID. + + Args: + booking_id: The booking ID (e.g., "BK-12345678") + + Returns: + The booking details + """ + if booking_id not in self._reservations: + raise ValueError(f"Booking not found: {booking_id}") + return self._reservations[booking_id] + + def get_my_bookings(self) -> list[Booking]: + """ + Retrieve all confirmed bookings for the current user. + + Returns: + List of all confirmed bookings + """ + return [ + booking + for booking in self._reservations.values() + if booking.status.status == "confirmed" + ] + + def add_service_to_booking( + self, + booking_id: str, + flight_id: str, + service_type: str, + ) -> Booking: + """ + Add a service (e.g., checked bag, carry-on) to an existing booking. + Updates the booking's total price and updated_at timestamp. + + Args: + booking_id: The booking ID (e.g., "BK-12345678") + flight_id: The flight ID within the booking to add the service to + service_type: Type of service to add (checked_bag, carry_on, seat_selection, etc.) + + Returns: + The updated booking with the new service added + """ + if booking_id not in self._reservations: + raise ValueError(f"Booking not found: {booking_id}") + + booking = self._reservations[booking_id] + + # Find the flight in the booking + flight_booking = next( + (fb for fb in booking.flights if fb.flight_id == flight_id), + None + ) + if not flight_booking: + available_flights = [fb.flight_id for fb in booking.flights] + raise ValueError( + f"Flight {flight_id} not found in booking {booking_id}. " + f"Available flights: {available_flights}" + ) + + # Get the flight to check available add-ons + if flight_id not in self._flights: + raise ValueError(f"Flight not found: {flight_id}") + + flight = self._flights[flight_id] + + # Find the add-on option + addon_option = next( + (ao for ao in flight.add_ons if ao.service_type == service_type), + None + ) + if not addon_option: + available_addons = [ao.service_type for ao in flight.add_ons] + raise ValueError( + f"Service '{service_type}' not available for flight {flight_id}. " + f"Available add-ons: {available_addons}" + ) + + # Check if service is already included in the fare + if service_type in flight_booking.included_services: + raise ValueError( + f"Service '{service_type}' is already included in the {flight_booking.fare_type} fare " + f"for flight {flight_id}" + ) + + # Check if add-on already exists + existing_addon = next( + (ao for ao in flight_booking.add_ons if ao.service_type == service_type), + None + ) + if existing_addon: + raise ValueError( + f"Service '{service_type}' has already been added to flight {flight_id} in this booking" + ) + + # Add the service add-on + now = datetime.now() + flight_booking.add_ons.append( + ServiceAddOn( + service_type=service_type, + price=addon_option.price, + currency=addon_option.currency, + added_at=now, + ) + ) + + # Update booking timestamp + booking.status.updated_at = now + + # Save the updated booking + self._reservations[booking_id] = booking + self._save_reservations() + + return booking + diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py new file mode 100644 index 0000000..3cf36c4 --- /dev/null +++ b/src/airline_agent/types/booking.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, computed_field + + +Cabin = Literal["economy", "premium_economy", "business", "first"] +FareType = Literal["basic", "standard", "flexible"] +ServiceType = Literal["checked_bag", "carry_on", "seat_selection", "priority_boarding", "travel_insurance"] + + +class ServiceAddOn(BaseModel): + """A service add-on purchased for a flight.""" + service_type: ServiceType + price: float + currency: str = "USD" + added_at: datetime # When was this add-on added + + +class Fare(BaseModel): + cabin: Cabin + fare_type: FareType = "basic" # Which fare bundle (basic, standard, flexible) + price_total: float # per passenger + currency: str = "USD" + seats_available: int + included_carry_on: bool = False # Does this fare include carry-on? + included_checked_bag: bool = False # Does this fare include checked bag? + + +class ServiceAddOnOption(BaseModel): + """Available service add-on options for a flight.""" + service_type: ServiceType + price: float + currency: str = "USD" + description: str = "" + + +class Flight(BaseModel): + id: str + origin: str + destination: str + departure: datetime + arrival: datetime + flight_number: str + carrier: str = "F9" + fares: list[Fare] + add_ons: list[ServiceAddOnOption] = [] # Available add-ons for this flight + + +class BookingStatus(BaseModel): + status: Literal["confirmed", "cancelled", "pending"] + created_at: datetime + updated_at: datetime + + +class FlightBooking(BaseModel): + """Represents a single flight within a booking.""" + flight_id: str + cabin: Cabin + fare_type: FareType = "basic" # Which fare bundle was purchased + + # Base fare pricing + base_price: float # Price of the fare itself + currency: str = "USD" + + # Services included in the fare (e.g., "carry_on" for standard, "checked_bag" for flexible) + included_services: list[str] = [] # e.g., ["carry_on"] or ["checked_bag"] + + # Add-on services purchased separately + add_ons: list[ServiceAddOn] = [] + + @computed_field + @property + def price_total(self) -> float: + """Calculate total price including base fare + add-ons.""" + return self.base_price + sum(addon.price for addon in self.add_ons) + + +class Booking(BaseModel): + """Represents a complete booking containing one or more flights.""" + booking_id: str + flights: list[FlightBooking] + currency: str = "USD" + status: BookingStatus + + @computed_field + @property + def total_price(self) -> float: + """Calculate total price across all flights.""" + return sum(flight.price_total for flight in self.flights) + From 4ad8012c4b49cf257013d57c1b6b95ac24f8984d Mon Sep 17 00:00:00 2001 From: charlesmeng18 Date: Fri, 31 Oct 2025 16:55:09 -0700 Subject: [PATCH 02/26] Add fare types and baggage options to booking system --- scripts/generate_flights.py | 313 ++++++++++-------- .../backend/services/airline_chat.py | 3 +- src/airline_agent/tools/booking.py | 141 ++++---- src/airline_agent/types/booking.py | 23 +- 4 files changed, 252 insertions(+), 228 deletions(-) diff --git a/scripts/generate_flights.py b/scripts/generate_flights.py index 6277d01..c5a261f 100644 --- a/scripts/generate_flights.py +++ b/scripts/generate_flights.py @@ -64,38 +64,40 @@ FLIGHT_DURATIONS[(orig, "OAK")] = duration - 0.1 if duration > 2 else duration # Hub to NYC routes -FLIGHT_DURATIONS.update({ - ("DEN", "JFK"): 3.5, - ("DEN", "EWR"): 3.3, - ("DEN", "LGA"): 3.4, - ("ORD", "JFK"): 2.0, - ("ORD", "EWR"): 2.0, - ("ORD", "LGA"): 2.0, - ("ATL", "JFK"): 2.5, - ("ATL", "EWR"): 2.3, - ("ATL", "LGA"): 2.4, - ("DFW", "JFK"): 3.5, - ("DFW", "EWR"): 3.3, - ("DFW", "LGA"): 3.4, - ("LAS", "JFK"): 5.0, - ("LAS", "EWR"): 4.8, - ("LAS", "LGA"): 4.9, - ("PHX", "JFK"): 4.5, - ("PHX", "EWR"): 4.3, - ("PHX", "LGA"): 4.4, - ("SEA", "JFK"): 5.5, - ("SEA", "EWR"): 5.3, - ("SEA", "LGA"): 5.4, - ("IAH", "JFK"): 3.0, - ("IAH", "EWR"): 2.8, - ("IAH", "LGA"): 2.9, - ("MSP", "JFK"): 2.8, - ("MSP", "EWR"): 2.6, - ("MSP", "LGA"): 2.7, - ("DTW", "JFK"): 1.8, - ("DTW", "EWR"): 1.6, - ("DTW", "LGA"): 1.7, -}) +FLIGHT_DURATIONS.update( + { + ("DEN", "JFK"): 3.5, + ("DEN", "EWR"): 3.3, + ("DEN", "LGA"): 3.4, + ("ORD", "JFK"): 2.0, + ("ORD", "EWR"): 2.0, + ("ORD", "LGA"): 2.0, + ("ATL", "JFK"): 2.5, + ("ATL", "EWR"): 2.3, + ("ATL", "LGA"): 2.4, + ("DFW", "JFK"): 3.5, + ("DFW", "EWR"): 3.3, + ("DFW", "LGA"): 3.4, + ("LAS", "JFK"): 5.0, + ("LAS", "EWR"): 4.8, + ("LAS", "LGA"): 4.9, + ("PHX", "JFK"): 4.5, + ("PHX", "EWR"): 4.3, + ("PHX", "LGA"): 4.4, + ("SEA", "JFK"): 5.5, + ("SEA", "EWR"): 5.3, + ("SEA", "LGA"): 5.4, + ("IAH", "JFK"): 3.0, + ("IAH", "EWR"): 2.8, + ("IAH", "LGA"): 2.9, + ("MSP", "JFK"): 2.8, + ("MSP", "EWR"): 2.6, + ("MSP", "LGA"): 2.7, + ("DTW", "JFK"): 1.8, + ("DTW", "EWR"): 1.6, + ("DTW", "LGA"): 1.7, + } +) # NYC to Hub routes (reverse of hub to NYC) for (hub, nyc), duration in list(FLIGHT_DURATIONS.items()): @@ -160,91 +162,103 @@ def get_timezone_offset(airport: str) -> int: def generate_fares() -> list[dict]: """Generate random fares for a flight with different fare types.""" fares = [] - + # Always include economy with multiple fare types economy_config = CABIN_CONFIGS["economy"] base_price = random.uniform(*economy_config["price_range"]) - + # Basic fare: no bags included - fares.append({ - "cabin": "economy", - "fare_type": "basic", - "price_total": round(base_price, 2), - "currency": "USD", - "seats_available": random.randint(3, 12), - "included_carry_on": False, - "included_checked_bag": False, - }) - + fares.append( + { + "cabin": "economy", + "fare_type": "basic", + "price_total": round(base_price, 2), + "currency": "USD", + "seats_available": random.randint(3, 12), + "included_carry_on": False, + "included_checked_bag": False, + } + ) + # Standard fare: includes carry-on (+$15-25 more than basic) standard_price = base_price + random.uniform(15, 25) - fares.append({ - "cabin": "economy", - "fare_type": "standard", - "price_total": round(standard_price, 2), - "currency": "USD", - "seats_available": random.randint(2, 10), - "included_carry_on": True, - "included_checked_bag": False, - }) - + fares.append( + { + "cabin": "economy", + "fare_type": "standard", + "price_total": round(standard_price, 2), + "currency": "USD", + "seats_available": random.randint(2, 10), + "included_carry_on": True, + "included_checked_bag": False, + } + ) + # Flexible fare: includes checked bag (+$30-45 more than basic) flexible_price = base_price + random.uniform(30, 45) - fares.append({ - "cabin": "economy", - "fare_type": "flexible", - "price_total": round(flexible_price, 2), - "currency": "USD", - "seats_available": random.randint(1, 8), - "included_carry_on": True, - "included_checked_bag": True, - }) - + fares.append( + { + "cabin": "economy", + "fare_type": "flexible", + "price_total": round(flexible_price, 2), + "currency": "USD", + "seats_available": random.randint(1, 8), + "included_carry_on": True, + "included_checked_bag": True, + } + ) + # Randomly add premium economy (30% chance - Frontier has limited premium options) if random.random() < 0.3: premium_config = CABIN_CONFIGS["premium_economy"] premium_base = random.uniform(*premium_config["price_range"]) - + # Premium economy basic - fares.append({ - "cabin": "premium_economy", - "fare_type": "basic", - "price_total": round(premium_base, 2), - "currency": "USD", - "seats_available": random.randint(2, 6), - "included_carry_on": True, # Premium always includes carry-on - "included_checked_bag": False, - }) - + fares.append( + { + "cabin": "premium_economy", + "fare_type": "basic", + "price_total": round(premium_base, 2), + "currency": "USD", + "seats_available": random.randint(2, 6), + "included_carry_on": True, # Premium always includes carry-on + "included_checked_bag": False, + } + ) + # Premium economy flexible (with checked bag) - fares.append({ - "cabin": "premium_economy", - "fare_type": "flexible", - "price_total": round(premium_base + random.uniform(20, 35), 2), - "currency": "USD", - "seats_available": random.randint(1, 4), - "included_carry_on": True, - "included_checked_bag": True, - }) - + fares.append( + { + "cabin": "premium_economy", + "fare_type": "flexible", + "price_total": round(premium_base + random.uniform(20, 35), 2), + "currency": "USD", + "seats_available": random.randint(1, 4), + "included_carry_on": True, + "included_checked_bag": True, + } + ) + # Randomly add business (20% chance - Frontier has limited business class) if random.random() < 0.2: business_config = CABIN_CONFIGS["business"] business_base = random.uniform(*business_config["price_range"]) - + # Business class always includes everything - fares.append({ - "cabin": "business", - "fare_type": "flexible", # Business is always flexible - "price_total": round(business_base, 2), - "currency": "USD", - "seats_available": random.randint(1, 4), - "included_carry_on": True, - "included_checked_bag": True, - }) - + fares.append( + { + "cabin": "business", + "fare_type": "flexible", # Business is always flexible + "price_total": round(business_base, 2), + "currency": "USD", + "seats_available": random.randint(1, 4), + "included_carry_on": True, + "included_checked_bag": True, + } + ) + # No first class for Frontier Airlines - + return fares @@ -291,44 +305,46 @@ def generate_flight_id(origin: str, destination: str, departure: datetime, carri return f"{carrier}-{origin}-{destination}-{date_str}" -def generate_direct_flights(start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None) -> list[dict]: +def generate_direct_flights( + start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None +) -> list[dict]: """Generate direct flights from origin airports to destination airports.""" if origin_airports is None: origin_airports = SF_AIRPORTS if dest_airports is None: dest_airports = NYC_AIRPORTS - + flights = [] - + for day in range(num_days): date = start_date + timedelta(days=day) - + # Generate comprehensive flights - multiple per origin-destination pair for origin in origin_airports: for destination in dest_airports: # Generate 3-6 flights per origin-destination pair per day num_flights = random.randint(3, 6) - + for flight_num in range(num_flights): # Random departure time between 6 AM and 10 PM hour = random.randint(6, 22) minute = random.choice([0, 15, 30, 45]) - + carrier_code = CARRIER_CODE - + departure_time = date.replace(hour=hour, minute=minute, second=0, microsecond=0) - + # Calculate arrival time duration = get_flight_duration(origin, destination) arrival_time = departure_time + timedelta(hours=duration) - + # Adjust for timezone departure_offset = get_timezone_offset(origin) arrival_offset = get_timezone_offset(destination) - + departure_str = departure_time.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset:+03d}:00") arrival_str = arrival_time.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset:+03d}:00") - + flight = { "id": generate_flight_id(origin, destination, departure_time, carrier_code), "origin": origin, @@ -340,24 +356,26 @@ def generate_direct_flights(start_date: datetime, num_days: int = 8, origin_airp "fares": generate_fares(), "add_ons": generate_add_ons(), } - + flights.append(flight) - + return flights -def generate_connecting_flights(start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None) -> list[dict]: +def generate_connecting_flights( + start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None +) -> list[dict]: """Generate connecting flights from origin airports to destination airports via hub airports.""" if origin_airports is None: origin_airports = SF_AIRPORTS if dest_airports is None: dest_airports = NYC_AIRPORTS - + flights = [] - + for day in range(num_days): date = start_date + timedelta(days=day) - + # Generate comprehensive connecting routes - all combinations for origin in origin_airports: for destination in dest_airports: @@ -366,60 +384,64 @@ def generate_connecting_flights(start_date: datetime, num_days: int = 8, origin_ for hub in HUB_AIRPORTS: # Generate 1-3 connecting flights per hub per origin-destination pair per day num_routes = random.randint(1, 3) - + for _ in range(num_routes): carrier_code = CARRIER_CODE - + # First leg: Origin -> Hub hour1 = random.randint(6, 18) minute1 = random.choice([0, 15, 30, 45]) departure_time_leg1 = date.replace(hour=hour1, minute=minute1, second=0, microsecond=0) - + duration1 = get_flight_duration(origin, hub) arrival_time_leg1 = departure_time_leg1 + timedelta(hours=duration1) - + # Layover: 45 minutes to 3 hours layover_hours = random.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) departure_time_leg2 = arrival_time_leg1 + timedelta(hours=layover_hours) - + # Second leg: Hub -> Destination duration2 = get_flight_duration(hub, destination) arrival_time_leg2 = departure_time_leg2 + timedelta(hours=duration2) - + # First leg departure_offset_leg1 = get_timezone_offset(origin) arrival_offset_leg1 = get_timezone_offset(hub) - + flight1 = { "id": generate_flight_id(origin, hub, departure_time_leg1, carrier_code), "origin": origin, "destination": hub, - "departure": departure_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset_leg1:+03d}:00"), + "departure": departure_time_leg1.strftime( + f"%Y-%m-%dT%H:%M:00{departure_offset_leg1:+03d}:00" + ), "arrival": arrival_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg1:+03d}:00"), "flight_number": f"{carrier_code} {random.randint(100, 999)}", "carrier": carrier_code, "fares": generate_fares(), "add_ons": generate_add_ons(), } - + # Second leg departure_offset_leg2 = get_timezone_offset(hub) arrival_offset_leg2 = get_timezone_offset(destination) - + flight2 = { "id": generate_flight_id(hub, destination, departure_time_leg2, carrier_code), "origin": hub, "destination": destination, - "departure": departure_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset_leg2:+03d}:00"), + "departure": departure_time_leg2.strftime( + f"%Y-%m-%dT%H:%M:00{departure_offset_leg2:+03d}:00" + ), "arrival": arrival_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg2:+03d}:00"), "flight_number": f"{carrier_code} {random.randint(100, 999)}", "carrier": carrier_code, "fares": generate_fares(), "add_ons": generate_add_ons(), } - + flights.extend([flight1, flight2]) - + return flights @@ -427,47 +449,48 @@ def main(): """Main function to generate and save flight data.""" # Set random seed for reproducibility random.seed(42) - + # Start date: Halloween 2025 (October 31) + 1 week start_date = datetime(2025, 10, 31) num_days = 8 # Oct 31 - Nov 7 - + print("Generating comprehensive flight data for Halloween week 2025 (Oct 31 - Nov 7)...") - + # Generate SF -> NYC flights (direct only) print("Generating direct flights (SF -> NYC)...") - direct_flights_sf_to_nyc = generate_direct_flights(start_date, num_days=num_days, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS) + direct_flights_sf_to_nyc = generate_direct_flights( + start_date, num_days=num_days, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS + ) print(f"Generated {len(direct_flights_sf_to_nyc)} direct flights from SF to NYC") - + # Generate NYC -> SF flights (direct only) print("Generating direct flights (NYC -> SF)...") - direct_flights_nyc_to_sf = generate_direct_flights(start_date, num_days=num_days, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS) + direct_flights_nyc_to_sf = generate_direct_flights( + start_date, num_days=num_days, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS + ) print(f"Generated {len(direct_flights_nyc_to_sf)} direct flights from NYC to SF") - + # Combine all flights (direct only, no transfers) all_flights = direct_flights_sf_to_nyc + direct_flights_nyc_to_sf - + # Get the project root (two levels up from scripts/) project_root = Path(__file__).parent.parent flights_file = project_root / "data" / "flights.json" - + # Sort by departure time all_flights.sort(key=lambda x: x["departure"]) - - output_data = { - "flights": all_flights - } - + + output_data = {"flights": all_flights} + with open(flights_file, "w") as f: json.dump(output_data, f, indent=2) - + print(f"\n✓ Successfully saved {len(all_flights)} total flights to {flights_file}") print(f" - Direct flights SF->NYC: {len(direct_flights_sf_to_nyc)}") print(f" - Direct flights NYC->SF: {len(direct_flights_nyc_to_sf)}") - print(f" - All flights are DIRECT flights for Halloween week 2025 (Oct 31 - Nov 7)") - print(f" - No connecting/transfer flights included") + print(" - All flights are DIRECT flights for Halloween week 2025 (Oct 31 - Nov 7)") + print(" - No connecting/transfer flights included") if __name__ == "__main__": main() - diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 8be7398..9eb428e 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -47,8 +47,8 @@ run_cleanlab_validation_logging_tools, ) from airline_agent.constants import AGENT_INSTRUCTIONS, AGENT_MODEL -from airline_agent.tools.knowledge_base import KnowledgeBase from airline_agent.tools.booking import BookingTools +from airline_agent.tools.knowledge_base import KnowledgeBase load_dotenv() @@ -66,6 +66,7 @@ def create_agent(kb: KnowledgeBase, booking: BookingTools) -> Agent: kb.get_article, kb.search, kb.list_directory, + booking.get_current_date, booking.search_flights, booking.get_fare_details, booking.book_flights, diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index cbc92d5..3b2937b 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -4,15 +4,17 @@ import uuid from datetime import date, datetime from pathlib import Path +from typing import Any, cast from airline_agent.types.booking import ( Booking, + BookingStatus, + Cabin, + FareType, Flight, FlightBooking, - BookingStatus, ServiceAddOn, - ServiceAddOnOption, - FareType, + ServiceType, ) @@ -21,7 +23,7 @@ def __init__(self, flights_path: str, reservations_path: str | None = None): with open(flights_path) as f: raw = json.load(f) self._flights: dict[str, Flight] = {x["id"]: Flight(**x) for x in raw["flights"]} - + # Initialize reservations storage self._reservations_path = reservations_path self._reservations: dict[str, Booking] = {} @@ -37,10 +39,7 @@ def _load_reservations(self) -> None: if reservations_file.exists(): with open(reservations_file) as f: data = json.load(f) - self._reservations = { - bid: Booking(**booking_data) - for bid, booking_data in data.items() - } + self._reservations = {bid: Booking(**booking_data) for bid, booking_data in data.items()} else: # Create empty file if it doesn't exist reservations_file.parent.mkdir(parents=True, exist_ok=True) @@ -52,10 +51,7 @@ def _save_reservations(self) -> None: """Save reservations to JSON file.""" if not self._reservations_path: return - data = { - bid: booking.model_dump(mode="json") - for bid, booking in self._reservations.items() - } + data = {bid: booking.model_dump(mode="json") for bid, booking in self._reservations.items()} with open(self._reservations_path, "w") as f: json.dump(data, f, indent=2, default=str) @@ -73,7 +69,7 @@ def search_flights(self, origin: str, destination: str, departure_date: str) -> """ try: dep = date.fromisoformat(departure_date) - except Exception as e: # noqa: BLE001 + except Exception as e: raise ValueError(f"Invalid departure_date: {departure_date}") from e return [ @@ -84,7 +80,7 @@ def search_flights(self, origin: str, destination: str, departure_date: str) -> and fl.departure.date().isoformat() == dep.isoformat() ] - def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") -> dict: + def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") -> dict[str, Any]: """ Get detailed fare information including what's included and available add-ons. @@ -98,27 +94,24 @@ def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") """ if flight_id not in self._flights: raise ValueError(f"Flight not found: {flight_id}") - + flight = self._flights[flight_id] - + # Find the fare for the requested cabin and fare type - fare = next( - (f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), - None - ) + fare = next((f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), None) if not fare: available_fares = [(f.cabin, f.fare_type) for f in flight.fares] raise ValueError( f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " f"Available fares: {available_fares}" ) - + included_services = [] if fare.included_carry_on: included_services.append("carry_on") if fare.included_checked_bag: included_services.append("checked_bag") - + return { "flight_id": flight_id, "cabin": cabin, @@ -160,45 +153,40 @@ def book_flights( now = datetime.now() booking_id = f"BK-{uuid.uuid4().hex[:8].upper()}" - + flight_bookings: list[FlightBooking] = [] currency = "USD" for flight_id in flight_ids: if flight_id not in self._flights: raise ValueError(f"Flight not found: {flight_id}") - + flight = self._flights[flight_id] - + # Find the fare for the requested cabin and fare type - fare = next( - (f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), - None - ) + fare = next((f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), None) if not fare: available_fares = [(f.cabin, f.fare_type) for f in flight.fares] raise ValueError( f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " f"Available fares: {available_fares}" ) - + if fare.seats_available <= 0: - raise ValueError( - f"No seats available for fare '{fare_type}' in {cabin} cabin for flight {flight_id}" - ) - + raise ValueError(f"No seats available for fare '{fare_type}' in {cabin} cabin for flight {flight_id}") + # Determine included services included_services = [] if fare.included_carry_on: included_services.append("carry_on") if fare.included_checked_bag: included_services.append("checked_bag") - + flight_bookings.append( FlightBooking( flight_id=flight_id, - cabin=cabin, - fare_type=fare_type, + cabin=cast(Cabin, cabin), + fare_type=cast(FareType, fare_type), base_price=fare.price_total, currency=fare.currency, included_services=included_services, @@ -220,7 +208,7 @@ def book_flights( self._reservations[booking_id] = booking self._save_reservations() - + return booking def get_booking(self, booking_id: str) -> Booking: @@ -244,99 +232,106 @@ def get_my_bookings(self) -> list[Booking]: Returns: List of all confirmed bookings """ - return [ - booking - for booking in self._reservations.values() - if booking.status.status == "confirmed" - ] + return [booking for booking in self._reservations.values() if booking.status.status == "confirmed"] def add_service_to_booking( self, booking_id: str, flight_id: str, service_type: str, + seat_preference: str | None = None, + seat_assignment: str | None = None, ) -> Booking: """ - Add a service (e.g., checked bag, carry-on) to an existing booking. + Add a service (e.g., checked bag, carry-on, seat selection) to an existing booking. Updates the booking's total price and updated_at timestamp. Args: booking_id: The booking ID (e.g., "BK-12345678") flight_id: The flight ID within the booking to add the service to service_type: Type of service to add (checked_bag, carry_on, seat_selection, etc.) + seat_preference: For seat_selection, preference like "window", "aisle", "middle" (optional) + seat_assignment: For seat_selection, actual assigned seat like "12A", "15F" (optional) Returns: The updated booking with the new service added """ if booking_id not in self._reservations: raise ValueError(f"Booking not found: {booking_id}") - + booking = self._reservations[booking_id] - + # Find the flight in the booking - flight_booking = next( - (fb for fb in booking.flights if fb.flight_id == flight_id), - None - ) + flight_booking = next((fb for fb in booking.flights if fb.flight_id == flight_id), None) if not flight_booking: available_flights = [fb.flight_id for fb in booking.flights] raise ValueError( - f"Flight {flight_id} not found in booking {booking_id}. " - f"Available flights: {available_flights}" + f"Flight {flight_id} not found in booking {booking_id}. " f"Available flights: {available_flights}" ) - + # Get the flight to check available add-ons if flight_id not in self._flights: raise ValueError(f"Flight not found: {flight_id}") - + flight = self._flights[flight_id] - + # Find the add-on option - addon_option = next( - (ao for ao in flight.add_ons if ao.service_type == service_type), - None - ) + addon_option = next((ao for ao in flight.add_ons if ao.service_type == service_type), None) if not addon_option: available_addons = [ao.service_type for ao in flight.add_ons] raise ValueError( f"Service '{service_type}' not available for flight {flight_id}. " f"Available add-ons: {available_addons}" ) - + # Check if service is already included in the fare if service_type in flight_booking.included_services: raise ValueError( f"Service '{service_type}' is already included in the {flight_booking.fare_type} fare " f"for flight {flight_id}" ) - + # Check if add-on already exists - existing_addon = next( - (ao for ao in flight_booking.add_ons if ao.service_type == service_type), - None - ) + existing_addon = next((ao for ao in flight_booking.add_ons if ao.service_type == service_type), None) if existing_addon: - raise ValueError( - f"Service '{service_type}' has already been added to flight {flight_id} in this booking" - ) - + raise ValueError(f"Service '{service_type}' has already been added to flight {flight_id} in this booking") + # Add the service add-on now = datetime.now() + + # For seat selection, validate that preferences/assignments are only set for seat_selection + if service_type != "seat_selection" and (seat_preference or seat_assignment): + raise ValueError("seat_preference and seat_assignment can only be set for seat_selection service type") + flight_booking.add_ons.append( ServiceAddOn( - service_type=service_type, + service_type=cast(ServiceType, service_type), price=addon_option.price, currency=addon_option.currency, added_at=now, + seat_preference=seat_preference, + seat_assignment=seat_assignment, ) ) - + # Update booking timestamp booking.status.updated_at = now - + # Save the updated booking self._reservations[booking_id] = booking self._save_reservations() - + return booking + def get_current_date(self) -> dict[str, str]: + """ + Get the current date and time. + + Returns: + Dictionary with current date in YYYY-MM-DD format and ISO timestamp + """ + now = datetime.now() + return { + "date": now.date().isoformat(), # YYYY-MM-DD format + "datetime": now.isoformat(), # Full ISO timestamp with timezone + } diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py index 3cf36c4..69c3c45 100644 --- a/src/airline_agent/types/booking.py +++ b/src/airline_agent/types/booking.py @@ -5,7 +5,6 @@ from pydantic import BaseModel, computed_field - Cabin = Literal["economy", "premium_economy", "business", "first"] FareType = Literal["basic", "standard", "flexible"] ServiceType = Literal["checked_bag", "carry_on", "seat_selection", "priority_boarding", "travel_insurance"] @@ -13,10 +12,14 @@ class ServiceAddOn(BaseModel): """A service add-on purchased for a flight.""" + service_type: ServiceType price: float currency: str = "USD" added_at: datetime # When was this add-on added + # Seat selection specific fields + seat_preference: str | None = None # e.g., "window", "aisle", "middle" + seat_assignment: str | None = None # e.g., "12A", "15F" - actual assigned seat class Fare(BaseModel): @@ -31,6 +34,7 @@ class Fare(BaseModel): class ServiceAddOnOption(BaseModel): """Available service add-on options for a flight.""" + service_type: ServiceType price: float currency: str = "USD" @@ -57,21 +61,22 @@ class BookingStatus(BaseModel): class FlightBooking(BaseModel): """Represents a single flight within a booking.""" + flight_id: str cabin: Cabin fare_type: FareType = "basic" # Which fare bundle was purchased - + # Base fare pricing base_price: float # Price of the fare itself currency: str = "USD" - + # Services included in the fare (e.g., "carry_on" for standard, "checked_bag" for flexible) included_services: list[str] = [] # e.g., ["carry_on"] or ["checked_bag"] - + # Add-on services purchased separately add_ons: list[ServiceAddOn] = [] - - @computed_field + + @computed_field # type: ignore[prop-decorator] @property def price_total(self) -> float: """Calculate total price including base fare + add-ons.""" @@ -80,14 +85,14 @@ def price_total(self) -> float: class Booking(BaseModel): """Represents a complete booking containing one or more flights.""" + booking_id: str flights: list[FlightBooking] currency: str = "USD" status: BookingStatus - - @computed_field + + @computed_field # type: ignore[prop-decorator] @property def total_price(self) -> float: """Calculate total price across all flights.""" return sum(flight.price_total for flight in self.flights) - From 1c26515efbfc87cb113dd9416dab4621cf6195b9 Mon Sep 17 00:00:00 2001 From: charlesmeng18 Date: Fri, 31 Oct 2025 17:18:39 -0700 Subject: [PATCH 03/26] added check into reservations and flight status lookup --- .../backend/services/airline_chat.py | 3 + src/airline_agent/tools/booking.py | 237 +++++++++++++++++- src/airline_agent/types/booking.py | 28 +++ 3 files changed, 267 insertions(+), 1 deletion(-) diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 9eb428e..c5b944c 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -73,6 +73,9 @@ def create_agent(kb: KnowledgeBase, booking: BookingTools) -> Agent: booking.get_booking, booking.get_my_bookings, booking.add_service_to_booking, + booking.check_in, + booking.get_flight_timings, + booking.get_flight_status, ], ) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 3b2937b..205ae05 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -1,8 +1,9 @@ from __future__ import annotations import json +import random import uuid -from datetime import date, datetime +from datetime import date, datetime, timedelta from pathlib import Path from typing import Any, cast @@ -13,6 +14,7 @@ FareType, Flight, FlightBooking, + FlightStatus, ServiceAddOn, ServiceType, ) @@ -20,6 +22,7 @@ class BookingTools: def __init__(self, flights_path: str, reservations_path: str | None = None): + self._flights_path = flights_path with open(flights_path) as f: raw = json.load(f) self._flights: dict[str, Flight] = {x["id"]: Flight(**x) for x in raw["flights"]} @@ -55,6 +58,12 @@ def _save_reservations(self) -> None: with open(self._reservations_path, "w") as f: json.dump(data, f, indent=2, default=str) + def _save_flights(self) -> None: + """Save flights to JSON file.""" + data = {"flights": [flight.model_dump(mode="json") for flight in self._flights.values()]} + with open(self._flights_path, "w") as f: + json.dump(data, f, indent=2, default=str) + def search_flights(self, origin: str, destination: str, departure_date: str) -> list[Flight]: """ Search available flights by route and date. @@ -335,3 +344,229 @@ def get_current_date(self) -> dict[str, str]: "date": now.date().isoformat(), # YYYY-MM-DD format "datetime": now.isoformat(), # Full ISO timestamp with timezone } + + def _assign_seat(self, flight_booking: FlightBooking, cabin: Cabin, flight_id: str) -> str: + """Assign a seat to a flight booking based on preferences or randomly.""" + # Check if seat_selection add-on exists with an assignment + seat_addon = next( + (addon for addon in flight_booking.add_ons if addon.service_type == "seat_selection"), + None, + ) + if seat_addon and seat_addon.seat_assignment: + return seat_addon.seat_assignment + + # If seat selection exists with preference, try to honor it + preference = seat_addon.seat_preference if seat_addon else None + + # Generate random seat assignment + # Rows vary by cabin: economy 1-40, premium 5-25, business 1-10, first 1-4 + row_ranges = { + "economy": (1, 40), + "premium_economy": (5, 25), + "business": (1, 10), + "first": (1, 4), + } + row_min, row_max = row_ranges.get(cabin, (1, 40)) + row = random.randint(row_min, row_max) + + # Seat letters (typical 3-3 configuration: A, B, C, D, E, F) + seat_letters = ["A", "B", "C", "D", "E", "F"] + window_seats = ["A", "F"] + aisle_seats = ["C", "D"] + + if preference == "window": + seat_letter = random.choice(window_seats) + elif preference == "aisle": + seat_letter = random.choice(aisle_seats) + else: + seat_letter = random.choice(seat_letters) + + return f"{row}{seat_letter}" + + def _assign_gates_and_terminals(self, flight: Flight) -> None: + """Assign gates and terminals to a flight if not already assigned.""" + # Assign departure terminal and gate if not already assigned + if not flight.departure_terminal: + terminals = ["Terminal 1", "Terminal 2", "Terminal 3", "Terminal A", "Terminal B"] + flight.departure_terminal = random.choice(terminals) + + if not flight.departure_gate: + # Generate a gate like "A15", "B22", "C8" + gate_letters = ["A", "B", "C", "D"] + gate_letter = random.choice(gate_letters) + gate_number = random.randint(1, 50) + flight.departure_gate = f"{gate_letter}{gate_number}" + + # Assign arrival terminal and gate if not already assigned + if not flight.arrival_terminal: + terminals = ["Terminal 1", "Terminal 2", "Terminal 3", "Terminal A", "Terminal B"] + flight.arrival_terminal = random.choice(terminals) + + if not flight.arrival_gate: + gate_letters = ["A", "B", "C", "D", "E"] + gate_letter = random.choice(gate_letters) + gate_number = random.randint(1, 60) + flight.arrival_gate = f"{gate_letter}{gate_number}" + + def _calculate_check_in_timings(self, departure: datetime) -> dict[str, datetime]: + """Calculate check-in and boarding timing windows.""" + check_in_opens = departure - timedelta(days=1) # 24 hours before + check_in_closes = departure - timedelta(minutes=45) # 45 minutes before + boarding_starts = departure - timedelta(minutes=30) # 30 minutes before + doors_close = departure - timedelta(minutes=15) # 15 minutes before + + return { + "check_in_opens_at": check_in_opens, + "check_in_closes_at": check_in_closes, + "boarding_starts_at": boarding_starts, + "doors_close_at": doors_close, + } + + def check_in(self, booking_id: str, flight_id: str) -> Booking: + """ + Check in for a specific flight in a booking. + + Args: + booking_id: The booking ID (e.g., "BK-12345678") + flight_id: The flight ID within the booking to check in for + + Returns: + The updated booking with check-in information + """ + if booking_id not in self._reservations: + raise ValueError(f"Booking not found: {booking_id}") + + booking = self._reservations[booking_id] + if booking.status.status != "confirmed": + raise ValueError(f"Cannot check in for booking {booking_id}: booking status is {booking.status.status}") + + # Find the flight in the booking + flight_booking = next((fb for fb in booking.flights if fb.flight_id == flight_id), None) + if not flight_booking: + available_flights = [fb.flight_id for fb in booking.flights] + raise ValueError( + f"Flight {flight_id} not found in booking {booking_id}. Available flights: {available_flights}" + ) + + if flight_booking.checked_in: + raise ValueError(f"Already checked in for flight {flight_id} in booking {booking_id}") + + # Get the flight details + if flight_id not in self._flights: + raise ValueError(f"Flight not found: {flight_id}") + + flight = self._flights[flight_id] + now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now() + + # Assign gates and terminals if needed + self._assign_gates_and_terminals(flight) + + # Assign seat if not already assigned + if not flight_booking.seat_assignment: + flight_booking.seat_assignment = self._assign_seat(flight_booking, flight_booking.cabin, flight_id) + + # Update check-in status + flight_booking.checked_in = True + flight_booking.checked_in_at = now + + # Update booking timestamp + booking.status.updated_at = now + + # Save changes + self._reservations[booking_id] = booking + self._save_reservations() + self._save_flights() + + return booking + + def get_flight_timings(self, flight_id: str) -> dict[str, Any]: + """ + Get all timing windows for a flight (check-in, boarding, doors close, etc.). + + Args: + flight_id: The flight ID + + Returns: + Dictionary with all timing windows and estimated times + """ + if flight_id not in self._flights: + raise ValueError(f"Flight not found: {flight_id}") + + flight = self._flights[flight_id] + timings = self._calculate_check_in_timings(flight.departure) + + return { + "flight_id": flight_id, + "flight_number": flight.flight_number, + "origin": flight.origin, + "destination": flight.destination, + "check_in_opens_at": timings["check_in_opens_at"].isoformat(), + "check_in_closes_at": timings["check_in_closes_at"].isoformat(), + "boarding_starts_at": timings["boarding_starts_at"].isoformat(), + "doors_close_at": timings["doors_close_at"].isoformat(), + "scheduled_departure": flight.departure.isoformat(), + "scheduled_arrival": flight.arrival.isoformat(), + "estimated_departure": ( + (flight.departure + timedelta(minutes=flight.delay_minutes or 0)).isoformat() + if flight.delay_minutes + else None + ), + "estimated_arrival": ( + (flight.arrival + timedelta(minutes=flight.delay_minutes or 0)).isoformat() + if flight.delay_minutes + else None + ), + } + + def get_flight_status(self, flight_id: str) -> dict[str, Any]: + """ + Get current flight status including gates, terminals, delays, etc. + + Args: + flight_id: The flight ID + + Returns: + Dictionary with current flight status and operational information + """ + if flight_id not in self._flights: + raise ValueError(f"Flight not found: {flight_id}") + + flight = self._flights[flight_id] + + # Auto-update gates/terminals if check-in window is open + self._assign_gates_and_terminals(flight) + + # Update flight status based on current time + now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now() + time_until_departure = flight.departure - now + + # Update status based on current time vs scheduled departure + if flight.status in ("scheduled", "on_time", "boarding"): + if time_until_departure.total_seconds() < -900: # 15 minutes past departure + flight.status = "departed" + elif time_until_departure.total_seconds() < 0: # Past departure time + flight.status = "departed" + elif time_until_departure.total_seconds() < 900: # Less than 15 minutes until departure + flight.status = "boarding" + elif time_until_departure.total_seconds() < 1800: # Less than 30 minutes until departure + flight.status = "on_time" + else: + flight.status = "on_time" + flight.status_updated_at = now + + return { + "flight_id": flight_id, + "flight_number": flight.flight_number, + "origin": flight.origin, + "destination": flight.destination, + "status": flight.status, + "status_updated_at": flight.status_updated_at.isoformat() if flight.status_updated_at else None, + "delay_minutes": flight.delay_minutes, + "departure_terminal": flight.departure_terminal, + "departure_gate": flight.departure_gate, + "arrival_terminal": flight.arrival_terminal, + "arrival_gate": flight.arrival_gate, + "scheduled_departure": flight.departure.isoformat(), + "scheduled_arrival": flight.arrival.isoformat(), + "carrier": flight.carrier, + } diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py index 69c3c45..821ace6 100644 --- a/src/airline_agent/types/booking.py +++ b/src/airline_agent/types/booking.py @@ -8,6 +8,15 @@ Cabin = Literal["economy", "premium_economy", "business", "first"] FareType = Literal["basic", "standard", "flexible"] ServiceType = Literal["checked_bag", "carry_on", "seat_selection", "priority_boarding", "travel_insurance"] +FlightStatus = Literal[ + "scheduled", + "on_time", + "delayed", + "boarding", + "departed", + "arrived", + "cancelled", +] class ServiceAddOn(BaseModel): @@ -52,6 +61,17 @@ class Flight(BaseModel): fares: list[Fare] add_ons: list[ServiceAddOnOption] = [] # Available add-ons for this flight + # Day-of travel information (enriched closer to departure) + departure_terminal: str | None = None # e.g., "Terminal 1", "Terminal A" + departure_gate: str | None = None # e.g., "A15", "B22" + arrival_terminal: str | None = None # e.g., "Terminal 3" + arrival_gate: str | None = None # e.g., "C8" + + # Flight status tracking + status: FlightStatus = "scheduled" + status_updated_at: datetime | None = None + delay_minutes: int | None = None # If delayed, minutes of delay + class BookingStatus(BaseModel): status: Literal["confirmed", "cancelled", "pending"] @@ -76,6 +96,14 @@ class FlightBooking(BaseModel): # Add-on services purchased separately add_ons: list[ServiceAddOn] = [] + # Check-in information + checked_in: bool = False + checked_in_at: datetime | None = None # When check-in was completed + + # Final seat assignment (may differ from seat_selection preference/addon) + # This is the actual assigned seat after check-in + seat_assignment: str | None = None # e.g., "12A", "15F" + @computed_field # type: ignore[prop-decorator] @property def price_total(self) -> float: From 0240c057324c8ee97126ac08c9f3601ac8b23ade Mon Sep 17 00:00:00 2001 From: charlesmeng18 Date: Mon, 3 Nov 2025 09:48:33 -0800 Subject: [PATCH 04/26] fixed formatting and linting --- scripts/generate_flights.py | 111 +++++++++++++++---------- src/airline_agent/tools/booking.py | 125 +++++++++++++++++------------ src/airline_agent/types/booking.py | 6 +- 3 files changed, 145 insertions(+), 97 deletions(-) mode change 100644 => 100755 scripts/generate_flights.py diff --git a/scripts/generate_flights.py b/scripts/generate_flights.py old mode 100644 new mode 100755 index c5a261f..1ca8a06 --- a/scripts/generate_flights.py +++ b/scripts/generate_flights.py @@ -7,9 +7,16 @@ import json import random -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from pathlib import Path +# Constants +SHORT_FLIGHT_THRESHOLD_HOURS = 2.0 # Threshold for short flights (hours) +DURATION_ADJUSTMENT = 0.1 # Adjustment for SJC/OAK flights (hours) +PROBABILITY_PREMIUM_ECONOMY = 0.3 # Probability of premium economy add-ons +PROBABILITY_BUSINESS = 0.2 # Probability of business add-ons +SF_BAY_AIRPORTS = {"SJC", "OAK"} # SF Bay Area airports (excluding SFO) + # San Francisco Bay Area airports SF_AIRPORTS = ["SFO", "SJC", "OAK"] @@ -53,15 +60,27 @@ # SJC routes (slightly shorter than SFO) for (orig, dest), duration in BASE_DURATIONS.items(): if orig == "SFO": - FLIGHT_DURATIONS[("SJC", dest)] = duration - 0.1 if duration > 2 else duration + if duration > SHORT_FLIGHT_THRESHOLD_HOURS: + FLIGHT_DURATIONS[("SJC", dest)] = duration - DURATION_ADJUSTMENT + else: + FLIGHT_DURATIONS[("SJC", dest)] = duration if dest == "SFO": - FLIGHT_DURATIONS[(orig, "SJC")] = duration - 0.1 if duration > 2 else duration + if duration > SHORT_FLIGHT_THRESHOLD_HOURS: + FLIGHT_DURATIONS[(orig, "SJC")] = duration - DURATION_ADJUSTMENT + else: + FLIGHT_DURATIONS[(orig, "SJC")] = duration # OAK routes (slightly shorter than SFO) for (orig, dest), duration in BASE_DURATIONS.items(): if orig == "SFO": - FLIGHT_DURATIONS[("OAK", dest)] = duration - 0.1 if duration > 2 else duration + if duration > SHORT_FLIGHT_THRESHOLD_HOURS: + FLIGHT_DURATIONS[("OAK", dest)] = duration - DURATION_ADJUSTMENT + else: + FLIGHT_DURATIONS[("OAK", dest)] = duration if dest == "SFO": - FLIGHT_DURATIONS[(orig, "OAK")] = duration - 0.1 if duration > 2 else duration + if duration > SHORT_FLIGHT_THRESHOLD_HOURS: + FLIGHT_DURATIONS[(orig, "OAK")] = duration - DURATION_ADJUSTMENT + else: + FLIGHT_DURATIONS[(orig, "OAK")] = duration # Hub to NYC routes FLIGHT_DURATIONS.update( @@ -115,8 +134,11 @@ if (sf, hub) not in FLIGHT_DURATIONS: # Use SFO duration as base base_duration = FLIGHT_DURATIONS.get(("SFO", hub), 3.0) - if sf == "SJC" or sf == "OAK": - FLIGHT_DURATIONS[(sf, hub)] = base_duration - 0.1 if base_duration > 2 else base_duration + if sf in SF_BAY_AIRPORTS: + if base_duration > SHORT_FLIGHT_THRESHOLD_HOURS: + FLIGHT_DURATIONS[(sf, hub)] = base_duration - DURATION_ADJUSTMENT + else: + FLIGHT_DURATIONS[(sf, hub)] = base_duration else: FLIGHT_DURATIONS[(sf, hub)] = base_duration @@ -165,7 +187,7 @@ def generate_fares() -> list[dict]: # Always include economy with multiple fare types economy_config = CABIN_CONFIGS["economy"] - base_price = random.uniform(*economy_config["price_range"]) + base_price = random.uniform(*economy_config["price_range"]) # noqa: S311 # Basic fare: no bags included fares.append( @@ -174,44 +196,44 @@ def generate_fares() -> list[dict]: "fare_type": "basic", "price_total": round(base_price, 2), "currency": "USD", - "seats_available": random.randint(3, 12), + "seats_available": random.randint(3, 12), # noqa: S311 "included_carry_on": False, "included_checked_bag": False, } ) # Standard fare: includes carry-on (+$15-25 more than basic) - standard_price = base_price + random.uniform(15, 25) + standard_price = base_price + random.uniform(15, 25) # noqa: S311 fares.append( { "cabin": "economy", "fare_type": "standard", "price_total": round(standard_price, 2), "currency": "USD", - "seats_available": random.randint(2, 10), + "seats_available": random.randint(2, 10), # noqa: S311 "included_carry_on": True, "included_checked_bag": False, } ) # Flexible fare: includes checked bag (+$30-45 more than basic) - flexible_price = base_price + random.uniform(30, 45) + flexible_price = base_price + random.uniform(30, 45) # noqa: S311 fares.append( { "cabin": "economy", "fare_type": "flexible", "price_total": round(flexible_price, 2), "currency": "USD", - "seats_available": random.randint(1, 8), + "seats_available": random.randint(1, 8), # noqa: S311 "included_carry_on": True, "included_checked_bag": True, } ) # Randomly add premium economy (30% chance - Frontier has limited premium options) - if random.random() < 0.3: + if random.random() < PROBABILITY_PREMIUM_ECONOMY: # noqa: S311 premium_config = CABIN_CONFIGS["premium_economy"] - premium_base = random.uniform(*premium_config["price_range"]) + premium_base = random.uniform(*premium_config["price_range"]) # noqa: S311 # Premium economy basic fares.append( @@ -220,7 +242,7 @@ def generate_fares() -> list[dict]: "fare_type": "basic", "price_total": round(premium_base, 2), "currency": "USD", - "seats_available": random.randint(2, 6), + "seats_available": random.randint(2, 6), # noqa: S311 "included_carry_on": True, # Premium always includes carry-on "included_checked_bag": False, } @@ -231,18 +253,18 @@ def generate_fares() -> list[dict]: { "cabin": "premium_economy", "fare_type": "flexible", - "price_total": round(premium_base + random.uniform(20, 35), 2), + "price_total": round(premium_base + random.uniform(20, 35), 2), # noqa: S311 "currency": "USD", - "seats_available": random.randint(1, 4), + "seats_available": random.randint(1, 4), # noqa: S311 "included_carry_on": True, "included_checked_bag": True, } ) # Randomly add business (20% chance - Frontier has limited business class) - if random.random() < 0.2: + if random.random() < PROBABILITY_BUSINESS: # noqa: S311 business_config = CABIN_CONFIGS["business"] - business_base = random.uniform(*business_config["price_range"]) + business_base = random.uniform(*business_config["price_range"]) # noqa: S311 # Business class always includes everything fares.append( @@ -251,7 +273,7 @@ def generate_fares() -> list[dict]: "fare_type": "flexible", # Business is always flexible "price_total": round(business_base, 2), "currency": "USD", - "seats_available": random.randint(1, 4), + "seats_available": random.randint(1, 4), # noqa: S311 "included_carry_on": True, "included_checked_bag": True, } @@ -264,39 +286,38 @@ def generate_fares() -> list[dict]: def generate_add_ons() -> list[dict]: """Generate available add-on services for a flight.""" - add_ons = [ + return [ { "service_type": "checked_bag", - "price": round(random.uniform(30, 40), 2), + "price": round(random.uniform(30, 40), 2), # noqa: S311 "currency": "USD", "description": "One checked bag (up to 50 lbs, 62 linear inches)", }, { "service_type": "carry_on", - "price": round(random.uniform(20, 30), 2), + "price": round(random.uniform(20, 30), 2), # noqa: S311 "currency": "USD", "description": "One carry-on bag (personal item included)", }, { "service_type": "seat_selection", - "price": round(random.uniform(10, 25), 2), + "price": round(random.uniform(10, 25), 2), # noqa: S311 "currency": "USD", "description": "Select your seat in advance", }, { "service_type": "priority_boarding", - "price": round(random.uniform(8, 15), 2), + "price": round(random.uniform(8, 15), 2), # noqa: S311 "currency": "USD", "description": "Priority boarding (Zone 2)", }, { "service_type": "travel_insurance", - "price": round(random.uniform(15, 30), 2), + "price": round(random.uniform(15, 30), 2), # noqa: S311 "currency": "USD", "description": "Trip protection insurance", }, ] - return add_ons def generate_flight_id(origin: str, destination: str, departure: datetime, carrier: str) -> str: @@ -306,7 +327,10 @@ def generate_flight_id(origin: str, destination: str, departure: datetime, carri def generate_direct_flights( - start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None + start_date: datetime, + num_days: int = 8, + origin_airports: list[str] | None = None, + dest_airports: list[str] | None = None, ) -> list[dict]: """Generate direct flights from origin airports to destination airports.""" if origin_airports is None: @@ -323,12 +347,12 @@ def generate_direct_flights( for origin in origin_airports: for destination in dest_airports: # Generate 3-6 flights per origin-destination pair per day - num_flights = random.randint(3, 6) + num_flights = random.randint(3, 6) # noqa: S311 - for flight_num in range(num_flights): + for _ in range(num_flights): # Random departure time between 6 AM and 10 PM - hour = random.randint(6, 22) - minute = random.choice([0, 15, 30, 45]) + hour = random.randint(6, 22) # noqa: S311 + minute = random.choice([0, 15, 30, 45]) # noqa: S311 carrier_code = CARRIER_CODE @@ -351,7 +375,7 @@ def generate_direct_flights( "destination": destination, "departure": departure_str, "arrival": arrival_str, - "flight_number": f"{carrier_code} {random.randint(100, 999)}", + "flight_number": f"{carrier_code} {random.randint(100, 999)}", # noqa: S311 "carrier": carrier_code, "fares": generate_fares(), "add_ons": generate_add_ons(), @@ -363,7 +387,10 @@ def generate_direct_flights( def generate_connecting_flights( - start_date: datetime, num_days: int = 8, origin_airports: list = None, dest_airports: list = None + start_date: datetime, + num_days: int = 8, + origin_airports: list[str] | None = None, + dest_airports: list[str] | None = None, ) -> list[dict]: """Generate connecting flights from origin airports to destination airports via hub airports.""" if origin_airports is None: @@ -383,21 +410,21 @@ def generate_connecting_flights( # Use all hubs to create many transfer options for hub in HUB_AIRPORTS: # Generate 1-3 connecting flights per hub per origin-destination pair per day - num_routes = random.randint(1, 3) + num_routes = random.randint(1, 3) # noqa: S311 for _ in range(num_routes): carrier_code = CARRIER_CODE # First leg: Origin -> Hub - hour1 = random.randint(6, 18) - minute1 = random.choice([0, 15, 30, 45]) + hour1 = random.randint(6, 18) # noqa: S311 + minute1 = random.choice([0, 15, 30, 45]) # noqa: S311 departure_time_leg1 = date.replace(hour=hour1, minute=minute1, second=0, microsecond=0) duration1 = get_flight_duration(origin, hub) arrival_time_leg1 = departure_time_leg1 + timedelta(hours=duration1) # Layover: 45 minutes to 3 hours - layover_hours = random.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) + layover_hours = random.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) # noqa: S311 departure_time_leg2 = arrival_time_leg1 + timedelta(hours=layover_hours) # Second leg: Hub -> Destination @@ -416,7 +443,7 @@ def generate_connecting_flights( f"%Y-%m-%dT%H:%M:00{departure_offset_leg1:+03d}:00" ), "arrival": arrival_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg1:+03d}:00"), - "flight_number": f"{carrier_code} {random.randint(100, 999)}", + "flight_number": f"{carrier_code} {random.randint(100, 999)}", # noqa: S311 "carrier": carrier_code, "fares": generate_fares(), "add_ons": generate_add_ons(), @@ -434,7 +461,7 @@ def generate_connecting_flights( f"%Y-%m-%dT%H:%M:00{departure_offset_leg2:+03d}:00" ), "arrival": arrival_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg2:+03d}:00"), - "flight_number": f"{carrier_code} {random.randint(100, 999)}", + "flight_number": f"{carrier_code} {random.randint(100, 999)}", # noqa: S311 "carrier": carrier_code, "fares": generate_fares(), "add_ons": generate_add_ons(), @@ -451,7 +478,7 @@ def main(): random.seed(42) # Start date: Halloween 2025 (October 31) + 1 week - start_date = datetime(2025, 10, 31) + start_date = datetime(2025, 10, 31, tzinfo=UTC) num_days = 8 # Oct 31 - Nov 7 print("Generating comprehensive flight data for Halloween week 2025 (Oct 31 - Nov 7)...") diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 205ae05..ece2552 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -3,7 +3,7 @@ import json import random import uuid -from datetime import date, datetime, timedelta +from datetime import UTC, date, datetime, timedelta from pathlib import Path from typing import Any, cast @@ -14,11 +14,15 @@ FareType, Flight, FlightBooking, - FlightStatus, ServiceAddOn, ServiceType, ) +# Constants for flight status timing (in seconds) +DEPARTURE_PAST_THRESHOLD = -900 # 15 minutes past departure +BOARDING_START_THRESHOLD = 900 # 15 minutes until departure +ON_TIME_THRESHOLD = 1800 # 30 minutes until departure + class BookingTools: def __init__(self, flights_path: str, reservations_path: str | None = None): @@ -79,7 +83,8 @@ def search_flights(self, origin: str, destination: str, departure_date: str) -> try: dep = date.fromisoformat(departure_date) except Exception as e: - raise ValueError(f"Invalid departure_date: {departure_date}") from e + msg = f"Invalid departure_date: {departure_date}" + raise ValueError(msg) from e return [ fl @@ -102,7 +107,8 @@ def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") Dictionary with fare details including included services and available add-ons """ if flight_id not in self._flights: - raise ValueError(f"Flight not found: {flight_id}") + msg = f"Flight not found: {flight_id}" + raise ValueError(msg) flight = self._flights[flight_id] @@ -110,10 +116,11 @@ def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") fare = next((f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), None) if not fare: available_fares = [(f.cabin, f.fare_type) for f in flight.fares] - raise ValueError( + msg = ( f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " f"Available fares: {available_fares}" ) + raise ValueError(msg) included_services = [] if fare.included_carry_on: @@ -158,9 +165,10 @@ def book_flights( The created booking with booking ID and total price """ if not flight_ids: - raise ValueError("At least one flight ID must be provided") + msg = "At least one flight ID must be provided" + raise ValueError(msg) - now = datetime.now() + now = datetime.now(UTC) booking_id = f"BK-{uuid.uuid4().hex[:8].upper()}" flight_bookings: list[FlightBooking] = [] @@ -168,7 +176,8 @@ def book_flights( for flight_id in flight_ids: if flight_id not in self._flights: - raise ValueError(f"Flight not found: {flight_id}") + msg = f"Flight not found: {flight_id}" + raise ValueError(msg) flight = self._flights[flight_id] @@ -176,13 +185,15 @@ def book_flights( fare = next((f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), None) if not fare: available_fares = [(f.cabin, f.fare_type) for f in flight.fares] - raise ValueError( + msg = ( f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " f"Available fares: {available_fares}" ) + raise ValueError(msg) if fare.seats_available <= 0: - raise ValueError(f"No seats available for fare '{fare_type}' in {cabin} cabin for flight {flight_id}") + msg = f"No seats available for fare '{fare_type}' in {cabin} cabin for flight {flight_id}" + raise ValueError(msg) # Determine included services included_services = [] @@ -231,7 +242,8 @@ def get_booking(self, booking_id: str) -> Booking: The booking details """ if booking_id not in self._reservations: - raise ValueError(f"Booking not found: {booking_id}") + msg = f"Booking not found: {booking_id}" + raise ValueError(msg) return self._reservations[booking_id] def get_my_bookings(self) -> list[Booking]: @@ -266,7 +278,8 @@ def add_service_to_booking( The updated booking with the new service added """ if booking_id not in self._reservations: - raise ValueError(f"Booking not found: {booking_id}") + msg = f"Booking not found: {booking_id}" + raise ValueError(msg) booking = self._reservations[booking_id] @@ -274,13 +287,13 @@ def add_service_to_booking( flight_booking = next((fb for fb in booking.flights if fb.flight_id == flight_id), None) if not flight_booking: available_flights = [fb.flight_id for fb in booking.flights] - raise ValueError( - f"Flight {flight_id} not found in booking {booking_id}. " f"Available flights: {available_flights}" - ) + msg = f"Flight {flight_id} not found in booking {booking_id}. Available flights: {available_flights}" + raise ValueError(msg) # Get the flight to check available add-ons if flight_id not in self._flights: - raise ValueError(f"Flight not found: {flight_id}") + msg = f"Flight not found: {flight_id}" + raise ValueError(msg) flight = self._flights[flight_id] @@ -288,29 +301,32 @@ def add_service_to_booking( addon_option = next((ao for ao in flight.add_ons if ao.service_type == service_type), None) if not addon_option: available_addons = [ao.service_type for ao in flight.add_ons] - raise ValueError( - f"Service '{service_type}' not available for flight {flight_id}. " - f"Available add-ons: {available_addons}" + msg = ( + f"Service '{service_type}' not available for flight {flight_id}. Available add-ons: {available_addons}" ) + raise ValueError(msg) # Check if service is already included in the fare if service_type in flight_booking.included_services: - raise ValueError( + msg = ( f"Service '{service_type}' is already included in the {flight_booking.fare_type} fare " f"for flight {flight_id}" ) + raise ValueError(msg) # Check if add-on already exists existing_addon = next((ao for ao in flight_booking.add_ons if ao.service_type == service_type), None) if existing_addon: - raise ValueError(f"Service '{service_type}' has already been added to flight {flight_id} in this booking") + msg = f"Service '{service_type}' has already been added to flight {flight_id} in this booking" + raise ValueError(msg) # Add the service add-on - now = datetime.now() + now = datetime.now(UTC) # For seat selection, validate that preferences/assignments are only set for seat_selection if service_type != "seat_selection" and (seat_preference or seat_assignment): - raise ValueError("seat_preference and seat_assignment can only be set for seat_selection service type") + msg = "seat_preference and seat_assignment can only be set for seat_selection service type" + raise ValueError(msg) flight_booking.add_ons.append( ServiceAddOn( @@ -339,13 +355,13 @@ def get_current_date(self) -> dict[str, str]: Returns: Dictionary with current date in YYYY-MM-DD format and ISO timestamp """ - now = datetime.now() + now = datetime.now(UTC) return { "date": now.date().isoformat(), # YYYY-MM-DD format "datetime": now.isoformat(), # Full ISO timestamp with timezone } - def _assign_seat(self, flight_booking: FlightBooking, cabin: Cabin, flight_id: str) -> str: + def _assign_seat(self, flight_booking: FlightBooking, cabin: Cabin, _flight_id: str) -> str: """Assign a seat to a flight booking based on preferences or randomly.""" # Check if seat_selection add-on exists with an assignment seat_addon = next( @@ -367,19 +383,19 @@ def _assign_seat(self, flight_booking: FlightBooking, cabin: Cabin, flight_id: s "first": (1, 4), } row_min, row_max = row_ranges.get(cabin, (1, 40)) - row = random.randint(row_min, row_max) + row = random.randint(row_min, row_max) # noqa: S311 # Seat letters (typical 3-3 configuration: A, B, C, D, E, F) seat_letters = ["A", "B", "C", "D", "E", "F"] window_seats = ["A", "F"] aisle_seats = ["C", "D"] - + if preference == "window": - seat_letter = random.choice(window_seats) + seat_letter = random.choice(window_seats) # noqa: S311 elif preference == "aisle": - seat_letter = random.choice(aisle_seats) + seat_letter = random.choice(aisle_seats) # noqa: S311 else: - seat_letter = random.choice(seat_letters) + seat_letter = random.choice(seat_letters) # noqa: S311 return f"{row}{seat_letter}" @@ -388,24 +404,24 @@ def _assign_gates_and_terminals(self, flight: Flight) -> None: # Assign departure terminal and gate if not already assigned if not flight.departure_terminal: terminals = ["Terminal 1", "Terminal 2", "Terminal 3", "Terminal A", "Terminal B"] - flight.departure_terminal = random.choice(terminals) + flight.departure_terminal = random.choice(terminals) # noqa: S311 if not flight.departure_gate: # Generate a gate like "A15", "B22", "C8" gate_letters = ["A", "B", "C", "D"] - gate_letter = random.choice(gate_letters) - gate_number = random.randint(1, 50) + gate_letter = random.choice(gate_letters) # noqa: S311 + gate_number = random.randint(1, 50) # noqa: S311 flight.departure_gate = f"{gate_letter}{gate_number}" # Assign arrival terminal and gate if not already assigned if not flight.arrival_terminal: terminals = ["Terminal 1", "Terminal 2", "Terminal 3", "Terminal A", "Terminal B"] - flight.arrival_terminal = random.choice(terminals) + flight.arrival_terminal = random.choice(terminals) # noqa: S311 if not flight.arrival_gate: gate_letters = ["A", "B", "C", "D", "E"] - gate_letter = random.choice(gate_letters) - gate_number = random.randint(1, 60) + gate_letter = random.choice(gate_letters) # noqa: S311 + gate_number = random.randint(1, 60) # noqa: S311 flight.arrival_gate = f"{gate_letter}{gate_number}" def _calculate_check_in_timings(self, departure: datetime) -> dict[str, datetime]: @@ -434,29 +450,32 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: The updated booking with check-in information """ if booking_id not in self._reservations: - raise ValueError(f"Booking not found: {booking_id}") + msg = f"Booking not found: {booking_id}" + raise ValueError(msg) booking = self._reservations[booking_id] if booking.status.status != "confirmed": - raise ValueError(f"Cannot check in for booking {booking_id}: booking status is {booking.status.status}") + msg = f"Cannot check in for booking {booking_id}: booking status is {booking.status.status}" + raise ValueError(msg) # Find the flight in the booking flight_booking = next((fb for fb in booking.flights if fb.flight_id == flight_id), None) if not flight_booking: available_flights = [fb.flight_id for fb in booking.flights] - raise ValueError( - f"Flight {flight_id} not found in booking {booking_id}. Available flights: {available_flights}" - ) + msg = f"Flight {flight_id} not found in booking {booking_id}. Available flights: {available_flights}" + raise ValueError(msg) if flight_booking.checked_in: - raise ValueError(f"Already checked in for flight {flight_id} in booking {booking_id}") + msg = f"Already checked in for flight {flight_id} in booking {booking_id}" + raise ValueError(msg) # Get the flight details if flight_id not in self._flights: - raise ValueError(f"Flight not found: {flight_id}") + msg = f"Flight not found: {flight_id}" + raise ValueError(msg) flight = self._flights[flight_id] - now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now() + now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now(UTC) # Assign gates and terminals if needed self._assign_gates_and_terminals(flight) @@ -490,7 +509,8 @@ def get_flight_timings(self, flight_id: str) -> dict[str, Any]: Dictionary with all timing windows and estimated times """ if flight_id not in self._flights: - raise ValueError(f"Flight not found: {flight_id}") + msg = f"Flight not found: {flight_id}" + raise ValueError(msg) flight = self._flights[flight_id] timings = self._calculate_check_in_timings(flight.departure) @@ -529,26 +549,25 @@ def get_flight_status(self, flight_id: str) -> dict[str, Any]: Dictionary with current flight status and operational information """ if flight_id not in self._flights: - raise ValueError(f"Flight not found: {flight_id}") + msg = f"Flight not found: {flight_id}" + raise ValueError(msg) flight = self._flights[flight_id] # Auto-update gates/terminals if check-in window is open self._assign_gates_and_terminals(flight) - + # Update flight status based on current time - now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now() + now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now(UTC) time_until_departure = flight.departure - now # Update status based on current time vs scheduled departure if flight.status in ("scheduled", "on_time", "boarding"): - if time_until_departure.total_seconds() < -900: # 15 minutes past departure - flight.status = "departed" - elif time_until_departure.total_seconds() < 0: # Past departure time + if time_until_departure.total_seconds() < 0: # Past departure time flight.status = "departed" - elif time_until_departure.total_seconds() < 900: # Less than 15 minutes until departure + elif time_until_departure.total_seconds() < BOARDING_START_THRESHOLD: flight.status = "boarding" - elif time_until_departure.total_seconds() < 1800: # Less than 30 minutes until departure + elif time_until_departure.total_seconds() < ON_TIME_THRESHOLD: flight.status = "on_time" else: flight.status = "on_time" diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py index 821ace6..260b828 100644 --- a/src/airline_agent/types/booking.py +++ b/src/airline_agent/types/booking.py @@ -1,10 +1,12 @@ from __future__ import annotations -from datetime import datetime -from typing import Literal +from typing import TYPE_CHECKING, Literal from pydantic import BaseModel, computed_field +if TYPE_CHECKING: + from datetime import datetime + Cabin = Literal["economy", "premium_economy", "business", "first"] FareType = Literal["basic", "standard", "flexible"] ServiceType = Literal["checked_bag", "carry_on", "seat_selection", "priority_boarding", "travel_insurance"] From 8807b0b7d980957305821836dd5677ad1586262e Mon Sep 17 00:00:00 2001 From: charlesmeng18 Date: Mon, 3 Nov 2025 11:05:09 -0800 Subject: [PATCH 05/26] updated to reflect Frontier's Fare to Services mapping, and no separate cabin classes --- scripts/generate_flights.py | 168 ++++++++++++++--------------- src/airline_agent/constants.py | 2 +- src/airline_agent/tools/booking.py | 127 +++++++++++----------- src/airline_agent/types/booking.py | 40 ++++--- 4 files changed, 175 insertions(+), 162 deletions(-) diff --git a/scripts/generate_flights.py b/scripts/generate_flights.py index 1ca8a06..ec9157b 100755 --- a/scripts/generate_flights.py +++ b/scripts/generate_flights.py @@ -13,8 +13,6 @@ # Constants SHORT_FLIGHT_THRESHOLD_HOURS = 2.0 # Threshold for short flights (hours) DURATION_ADJUSTMENT = 0.1 # Adjustment for SJC/OAK flights (hours) -PROBABILITY_PREMIUM_ECONOMY = 0.3 # Probability of premium economy add-ons -PROBABILITY_BUSINESS = 0.2 # Probability of business add-ons SF_BAY_AIRPORTS = {"SJC", "OAK"} # SF Bay Area airports (excluding SFO) # San Francisco Bay Area airports @@ -30,11 +28,12 @@ CARRIER_CODE = "F9" CARRIER_NAME = "Frontier" -# Cabin configurations and base prices (Frontier Airlines style) -CABIN_CONFIGS = { - "economy": {"base_price": 120, "price_range": (80, 180)}, - "premium_economy": {"base_price": 200, "price_range": (150, 280)}, - "business": {"base_price": 350, "price_range": (280, 500)}, +# Fare bundle base prices (Frontier Airlines style - no separate cabin classes) +FARE_BASE_PRICES = { + "basic": {"price_range": (80, 150)}, + "economy": {"price_range": (120, 200)}, + "premium": {"price_range": (200, 320)}, + "business": {"price_range": (350, 550)}, } # Flight duration estimates (in hours) @@ -182,104 +181,75 @@ def get_timezone_offset(airport: str) -> int: def generate_fares() -> list[dict]: - """Generate random fares for a flight with different fare types.""" + """Generate random fares for a flight with different fare bundles (Frontier Airlines model).""" fares = [] - # Always include economy with multiple fare types - economy_config = CABIN_CONFIGS["economy"] - base_price = random.uniform(*economy_config["price_range"]) # noqa: S311 - - # Basic fare: no bags included + # Basic fare: no services included + basic_price = random.uniform(*FARE_BASE_PRICES["basic"]["price_range"]) # noqa: S311 fares.append( { - "cabin": "economy", "fare_type": "basic", - "price_total": round(base_price, 2), + "price_total": round(basic_price, 2), "currency": "USD", - "seats_available": random.randint(3, 12), # noqa: S311 - "included_carry_on": False, - "included_checked_bag": False, + "seats_available": random.randint(5, 15), # noqa: S311 + "included_services": [], + "checked_bags_included": 0, } ) - # Standard fare: includes carry-on (+$15-25 more than basic) - standard_price = base_price + random.uniform(15, 25) # noqa: S311 + # Economy bundle: Basic + Carry on, Standard seat selection, Refundability, Change/cancel fee waived + economy_price = random.uniform(*FARE_BASE_PRICES["economy"]["price_range"]) # noqa: S311 fares.append( { - "cabin": "economy", - "fare_type": "standard", - "price_total": round(standard_price, 2), + "fare_type": "economy", + "price_total": round(economy_price, 2), "currency": "USD", - "seats_available": random.randint(2, 10), # noqa: S311 - "included_carry_on": True, - "included_checked_bag": False, + "seats_available": random.randint(3, 12), # noqa: S311 + "included_services": ["carry_on", "standard_seat_selection", "refundability", "change_cancel_fee_waived"], + "checked_bags_included": 0, } ) - # Flexible fare: includes checked bag (+$30-45 more than basic) - flexible_price = base_price + random.uniform(30, 45) # noqa: S311 + # Premium bundle: Economy + Premium seat selection + Priority Boarding + premium_price = random.uniform(*FARE_BASE_PRICES["premium"]["price_range"]) # noqa: S311 fares.append( { - "cabin": "economy", - "fare_type": "flexible", - "price_total": round(flexible_price, 2), + "fare_type": "premium", + "price_total": round(premium_price, 2), "currency": "USD", - "seats_available": random.randint(1, 8), # noqa: S311 - "included_carry_on": True, - "included_checked_bag": True, + "seats_available": random.randint(2, 8), # noqa: S311 + "included_services": [ + "carry_on", + "standard_seat_selection", + "refundability", + "change_cancel_fee_waived", + "premium_seat_selection", + "priority_boarding", + ], + "checked_bags_included": 0, } ) - # Randomly add premium economy (30% chance - Frontier has limited premium options) - if random.random() < PROBABILITY_PREMIUM_ECONOMY: # noqa: S311 - premium_config = CABIN_CONFIGS["premium_economy"] - premium_base = random.uniform(*premium_config["price_range"]) # noqa: S311 - - # Premium economy basic - fares.append( - { - "cabin": "premium_economy", - "fare_type": "basic", - "price_total": round(premium_base, 2), - "currency": "USD", - "seats_available": random.randint(2, 6), # noqa: S311 - "included_carry_on": True, # Premium always includes carry-on - "included_checked_bag": False, - } - ) - - # Premium economy flexible (with checked bag) - fares.append( - { - "cabin": "premium_economy", - "fare_type": "flexible", - "price_total": round(premium_base + random.uniform(20, 35), 2), # noqa: S311 - "currency": "USD", - "seats_available": random.randint(1, 4), # noqa: S311 - "included_carry_on": True, - "included_checked_bag": True, - } - ) - - # Randomly add business (20% chance - Frontier has limited business class) - if random.random() < PROBABILITY_BUSINESS: # noqa: S311 - business_config = CABIN_CONFIGS["business"] - business_base = random.uniform(*business_config["price_range"]) # noqa: S311 - - # Business class always includes everything - fares.append( - { - "cabin": "business", - "fare_type": "flexible", # Business is always flexible - "price_total": round(business_base, 2), - "currency": "USD", - "seats_available": random.randint(1, 4), # noqa: S311 - "included_carry_on": True, - "included_checked_bag": True, - } - ) - - # No first class for Frontier Airlines + # Business bundle: Premium + 2 checked bags + UpFront Plus Seating + business_price = random.uniform(*FARE_BASE_PRICES["business"]["price_range"]) # noqa: S311 + fares.append( + { + "fare_type": "business", + "price_total": round(business_price, 2), + "currency": "USD", + "seats_available": random.randint(1, 4), # noqa: S311 + "included_services": [ + "carry_on", + "standard_seat_selection", + "refundability", + "change_cancel_fee_waived", + "premium_seat_selection", + "priority_boarding", + "upfront_plus_seating", + ], + "checked_bags_included": 2, + } + ) return fares @@ -300,16 +270,28 @@ def generate_add_ons() -> list[dict]: "description": "One carry-on bag (personal item included)", }, { - "service_type": "seat_selection", + "service_type": "standard_seat_selection", "price": round(random.uniform(10, 25), 2), # noqa: S311 "currency": "USD", - "description": "Select your seat in advance", + "description": "Select a standard seat in advance", + }, + { + "service_type": "premium_seat_selection", + "price": round(random.uniform(25, 45), 2), # noqa: S311 + "currency": "USD", + "description": "Select a stretch seat with extra legroom", + }, + { + "service_type": "upfront_plus_seating", + "price": round(random.uniform(50, 100), 2), # noqa: S311 + "currency": "USD", + "description": "UpFront Plus seating in first two rows with guaranteed empty middle seat", }, { "service_type": "priority_boarding", "price": round(random.uniform(8, 15), 2), # noqa: S311 "currency": "USD", - "description": "Priority boarding (Zone 2)", + "description": "Priority boarding with overhead bin space", }, { "service_type": "travel_insurance", @@ -317,6 +299,18 @@ def generate_add_ons() -> list[dict]: "currency": "USD", "description": "Trip protection insurance", }, + { + "service_type": "refundability", + "price": round(random.uniform(30, 60), 2), # noqa: S311 + "currency": "USD", + "description": "Add refundability to your booking", + }, + { + "service_type": "change_cancel_fee_waived", + "price": round(random.uniform(20, 40), 2), # noqa: S311 + "currency": "USD", + "description": "Waive change and cancel fees", + }, ] diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index 60acfdd..984862f 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -19,7 +19,7 @@ - get_article — get the full article by its path. - list_directory — list directory structure to make more informed searches. - search_flights — search available flights by origin airport code, destination airport code, and departure date (YYYY-MM-DD format). Always ask for the departure date if the user doesn't provide it. Common city names like "NYC" are automatically mapped to airport codes. -- book_flights — book one or more flights for the current user. Requires list of flight IDs and cabin class (defaults to economy). Returns booking confirmation with booking ID and total price. +- book_flights — book one or more flights for the current user. Requires list of flight IDs and fare bundle type (basic, economy, premium, business; defaults to basic). Returns booking confirmation with booking ID and total price. - get_booking — retrieve booking details by booking ID. - get_my_bookings — retrieve all confirmed bookings for the current user. diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index ece2552..48769e9 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -10,7 +10,6 @@ from airline_agent.types.booking import ( Booking, BookingStatus, - Cabin, FareType, Flight, FlightBooking, @@ -94,14 +93,13 @@ def search_flights(self, origin: str, destination: str, departure_date: str) -> and fl.departure.date().isoformat() == dep.isoformat() ] - def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") -> dict[str, Any]: + def get_fare_details(self, flight_id: str, fare_type: str = "basic") -> dict[str, Any]: """ Get detailed fare information including what's included and available add-ons. Args: flight_id: The flight ID - cabin: Cabin class (economy, premium_economy, business, first) - fare_type: Fare type (basic, standard, flexible) + fare_type: Fare bundle type (basic, economy, premium, business) Returns: Dictionary with fare details including included services and available add-ons @@ -112,30 +110,21 @@ def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") flight = self._flights[flight_id] - # Find the fare for the requested cabin and fare type - fare = next((f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), None) + # Find the fare for the requested fare type (no cabin classes in Frontier model) + fare = next((f for f in flight.fares if f.fare_type == fare_type), None) if not fare: - available_fares = [(f.cabin, f.fare_type) for f in flight.fares] - msg = ( - f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " - f"Available fares: {available_fares}" - ) + available_fares = [f.fare_type for f in flight.fares] + msg = f"Fare '{fare_type}' not available for flight {flight_id}. " f"Available fares: {available_fares}" raise ValueError(msg) - included_services = [] - if fare.included_carry_on: - included_services.append("carry_on") - if fare.included_checked_bag: - included_services.append("checked_bag") - return { "flight_id": flight_id, - "cabin": cabin, "fare_type": fare_type, "price": fare.price_total, "currency": fare.currency, "seats_available": fare.seats_available, - "included_services": included_services, + "included_services": fare.included_services, + "checked_bags_included": fare.checked_bags_included, "available_add_ons": [ { "service_type": addon.service_type, @@ -150,7 +139,6 @@ def get_fare_details(self, flight_id: str, cabin: str, fare_type: str = "basic") def book_flights( self, flight_ids: list[str], - cabin: str = "economy", fare_type: str = "basic", ) -> Booking: """ @@ -158,8 +146,7 @@ def book_flights( Args: flight_ids: List of flight IDs to book - cabin: Cabin class (economy, premium_economy, business, first) - fare_type: Fare type (basic, standard, flexible). Defaults to "basic". + fare_type: Fare bundle type (basic, economy, premium, business). Defaults to "basic". Returns: The created booking with booking ID and total price @@ -181,35 +168,25 @@ def book_flights( flight = self._flights[flight_id] - # Find the fare for the requested cabin and fare type - fare = next((f for f in flight.fares if f.cabin == cabin and f.fare_type == fare_type), None) + # Find the fare for the requested fare type (no cabin classes in Frontier model) + fare = next((f for f in flight.fares if f.fare_type == fare_type), None) if not fare: - available_fares = [(f.cabin, f.fare_type) for f in flight.fares] - msg = ( - f"Fare '{fare_type}' in '{cabin}' cabin not available for flight {flight_id}. " - f"Available fares: {available_fares}" - ) + available_fares = [f.fare_type for f in flight.fares] + msg = f"Fare '{fare_type}' not available for flight {flight_id}. " f"Available fares: {available_fares}" raise ValueError(msg) if fare.seats_available <= 0: - msg = f"No seats available for fare '{fare_type}' in {cabin} cabin for flight {flight_id}" + msg = f"No seats available for fare '{fare_type}' for flight {flight_id}" raise ValueError(msg) - # Determine included services - included_services = [] - if fare.included_carry_on: - included_services.append("carry_on") - if fare.included_checked_bag: - included_services.append("checked_bag") - flight_bookings.append( FlightBooking( flight_id=flight_id, - cabin=cast(Cabin, cabin), fare_type=cast(FareType, fare_type), base_price=fare.price_total, currency=fare.currency, - included_services=included_services, + included_services=fare.included_services.copy(), + checked_bags_included=fare.checked_bags_included, add_ons=[], ) ) @@ -270,9 +247,9 @@ def add_service_to_booking( Args: booking_id: The booking ID (e.g., "BK-12345678") flight_id: The flight ID within the booking to add the service to - service_type: Type of service to add (checked_bag, carry_on, seat_selection, etc.) - seat_preference: For seat_selection, preference like "window", "aisle", "middle" (optional) - seat_assignment: For seat_selection, actual assigned seat like "12A", "15F" (optional) + service_type: Type of service to add (checked_bag, carry_on, standard_seat_selection, etc.) + seat_preference: For seat selection services, preference like "window", "aisle", "middle" (optional) + seat_assignment: For seat selection services, actual assigned seat like "12A", "15F" (optional) Returns: The updated booking with the new service added @@ -307,7 +284,15 @@ def add_service_to_booking( raise ValueError(msg) # Check if service is already included in the fare - if service_type in flight_booking.included_services: + # Special handling for checked_bag (tracked via count, not in included_services) + if service_type == "checked_bag": + if flight_booking.checked_bags_included > 0: + msg = ( + f"Checked bag(s) are already included in the {flight_booking.fare_type} fare " + f"for flight {flight_id} ({flight_booking.checked_bags_included} bag(s) included)" + ) + raise ValueError(msg) + elif service_type in flight_booking.included_services: msg = ( f"Service '{service_type}' is already included in the {flight_booking.fare_type} fare " f"for flight {flight_id}" @@ -323,11 +308,21 @@ def add_service_to_booking( # Add the service add-on now = datetime.now(UTC) - # For seat selection, validate that preferences/assignments are only set for seat_selection - if service_type != "seat_selection" and (seat_preference or seat_assignment): - msg = "seat_preference and seat_assignment can only be set for seat_selection service type" + # For seat selection services, validate that preferences/assignments are only set for seat selection + seat_selection_types = ["standard_seat_selection", "premium_seat_selection", "upfront_plus_seating"] + if service_type not in seat_selection_types and (seat_preference or seat_assignment): + msg = "seat_preference and seat_assignment can only be set for seat selection service types" raise ValueError(msg) + # Determine seat type based on service type + seat_type = None + if service_type == "standard_seat_selection": + seat_type = "standard" + elif service_type == "premium_seat_selection": + seat_type = "stretch" + elif service_type == "upfront_plus_seating": + seat_type = "upfront_plus" + flight_booking.add_ons.append( ServiceAddOn( service_type=cast(ServiceType, service_type), @@ -336,6 +331,7 @@ def add_service_to_booking( added_at=now, seat_preference=seat_preference, seat_assignment=seat_assignment, + seat_type=seat_type, ) ) @@ -361,11 +357,12 @@ def get_current_date(self) -> dict[str, str]: "datetime": now.isoformat(), # Full ISO timestamp with timezone } - def _assign_seat(self, flight_booking: FlightBooking, cabin: Cabin, _flight_id: str) -> str: - """Assign a seat to a flight booking based on preferences or randomly.""" - # Check if seat_selection add-on exists with an assignment + def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str: + """Assign a seat to a flight booking based on preferences, fare type, or randomly.""" + # Check if any seat selection add-on exists with an assignment + seat_selection_types = ["standard_seat_selection", "premium_seat_selection", "upfront_plus_seating"] seat_addon = next( - (addon for addon in flight_booking.add_ons if addon.service_type == "seat_selection"), + (addon for addon in flight_booking.add_ons if addon.service_type in seat_selection_types), None, ) if seat_addon and seat_addon.seat_assignment: @@ -373,16 +370,26 @@ def _assign_seat(self, flight_booking: FlightBooking, cabin: Cabin, _flight_id: # If seat selection exists with preference, try to honor it preference = seat_addon.seat_preference if seat_addon else None + seat_type = seat_addon.seat_type if seat_addon else None + + # Generate random seat assignment based on fare type and seat selection + # UpFront Plus: rows 1-2 (first two rows) + # Stretch: rows 3-15 (premium seating area) + # Standard: rows 16-40 (standard seating area) + + if seat_type == "upfront_plus": + row_min, row_max = (1, 2) + elif seat_type == "stretch": + row_min, row_max = (3, 15) + elif flight_booking.fare_type == "business": + # Business fare includes UpFront Plus seating + row_min, row_max = (1, 2) + elif seat_type == "standard": + row_min, row_max = (16, 40) + else: + # Default to standard seating area + row_min, row_max = (16, 40) - # Generate random seat assignment - # Rows vary by cabin: economy 1-40, premium 5-25, business 1-10, first 1-4 - row_ranges = { - "economy": (1, 40), - "premium_economy": (5, 25), - "business": (1, 10), - "first": (1, 4), - } - row_min, row_max = row_ranges.get(cabin, (1, 40)) row = random.randint(row_min, row_max) # noqa: S311 # Seat letters (typical 3-3 configuration: A, B, C, D, E, F) @@ -482,7 +489,7 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: # Assign seat if not already assigned if not flight_booking.seat_assignment: - flight_booking.seat_assignment = self._assign_seat(flight_booking, flight_booking.cabin, flight_id) + flight_booking.seat_assignment = self._assign_seat(flight_booking, flight_id) # Update check-in status flight_booking.checked_in = True diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py index 260b828..f9289f6 100644 --- a/src/airline_agent/types/booking.py +++ b/src/airline_agent/types/booking.py @@ -1,15 +1,22 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal +from datetime import datetime # noqa: TCH003 +from typing import Literal from pydantic import BaseModel, computed_field -if TYPE_CHECKING: - from datetime import datetime - -Cabin = Literal["economy", "premium_economy", "business", "first"] -FareType = Literal["basic", "standard", "flexible"] -ServiceType = Literal["checked_bag", "carry_on", "seat_selection", "priority_boarding", "travel_insurance"] +FareType = Literal["basic", "economy", "premium", "business"] +ServiceType = Literal[ + "checked_bag", + "carry_on", + "standard_seat_selection", + "premium_seat_selection", + "upfront_plus_seating", + "priority_boarding", + "travel_insurance", + "refundability", + "change_cancel_fee_waived", +] FlightStatus = Literal[ "scheduled", "on_time", @@ -31,16 +38,21 @@ class ServiceAddOn(BaseModel): # Seat selection specific fields seat_preference: str | None = None # e.g., "window", "aisle", "middle" seat_assignment: str | None = None # e.g., "12A", "15F" - actual assigned seat + seat_type: str | None = None # e.g., "standard", "stretch", "upfront_plus" - type of seat selected class Fare(BaseModel): - cabin: Cabin - fare_type: FareType = "basic" # Which fare bundle (basic, standard, flexible) + """Frontier Airlines fare bundle. No separate cabin classes - all passengers in same cabin.""" + + fare_type: FareType = "basic" # Which fare bundle (basic, economy, premium, business) price_total: float # per passenger currency: str = "USD" seats_available: int - included_carry_on: bool = False # Does this fare include carry-on? - included_checked_bag: bool = False # Does this fare include checked bag? + # Services included in this fare bundle (flat list, no nested references) + included_services: list[ + str + ] # e.g., ["carry_on", "standard_seat_selection", "refundability", "change_cancel_fee_waived"] + checked_bags_included: int = 0 # Number of checked bags included (0, 1, or 2 for business) class ServiceAddOnOption(BaseModel): @@ -85,15 +97,15 @@ class FlightBooking(BaseModel): """Represents a single flight within a booking.""" flight_id: str - cabin: Cabin fare_type: FareType = "basic" # Which fare bundle was purchased # Base fare pricing base_price: float # Price of the fare itself currency: str = "USD" - # Services included in the fare (e.g., "carry_on" for standard, "checked_bag" for flexible) - included_services: list[str] = [] # e.g., ["carry_on"] or ["checked_bag"] + # Services included in the fare (flat list, no nested references) + included_services: list[str] = [] # e.g., ["carry_on", "standard_seat_selection", "refundability"] + checked_bags_included: int = 0 # Number of checked bags included (0, 1, or 2 for business) # Add-on services purchased separately add_ons: list[ServiceAddOn] = [] From 8b90784d9a84dcd8aed5e087aa767b6a46d4593b Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Tue, 4 Nov 2025 11:55:29 -0800 Subject: [PATCH 06/26] Update instructions --- README.md | 6 +++++- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 890da5c..a9b6ef6 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,11 @@ This repository contains an example of a production-grade AI Agent that is integ ```console $ hatch run create-vector-database ``` -7. Install frontend dependencies: +7. Generate flight schedule data: + ```console + $ hatch run generate-flights + ``` +8. Install frontend dependencies: ```console $ cd frontend $ npm install diff --git a/pyproject.toml b/pyproject.toml index 1a3c022..e661ead 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ create-demo-project = "python -m airline_agent.cleanlab_utils.create_new_demo_pr open-project = "python -m airline_agent.cleanlab_utils.open_project {args}" fetch-pages = "python -m airline_agent.data_preparation.fetch_pages {args}" create-vector-database = "python -m airline_agent.preprocessing.create_vector_database {args}" +generate-flights = "python scripts/generate_flights.py {args}" backend-server = "python -m airline_agent.backend.app {args}" From 761cbc15626ee5a18f20189ca97045ef07f83adf Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Tue, 4 Nov 2025 13:07:26 -0800 Subject: [PATCH 07/26] Switch from comments to docstrings/Fields --- src/airline_agent/types/booking.py | 108 +++++++++++++++++------------ 1 file changed, 63 insertions(+), 45 deletions(-) diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py index f9289f6..7290fbd 100644 --- a/src/airline_agent/types/booking.py +++ b/src/airline_agent/types/booking.py @@ -1,9 +1,7 @@ -from __future__ import annotations - -from datetime import datetime # noqa: TCH003 +from datetime import datetime from typing import Literal -from pydantic import BaseModel, computed_field +from pydantic import BaseModel, Field, computed_field FareType = Literal["basic", "economy", "premium", "business"] ServiceType = Literal[ @@ -34,25 +32,39 @@ class ServiceAddOn(BaseModel): service_type: ServiceType price: float currency: str = "USD" - added_at: datetime # When was this add-on added - # Seat selection specific fields - seat_preference: str | None = None # e.g., "window", "aisle", "middle" - seat_assignment: str | None = None # e.g., "12A", "15F" - actual assigned seat - seat_type: str | None = None # e.g., "standard", "stretch", "upfront_plus" - type of seat selected + added_at: datetime = Field(..., description="Timestamp when the add-on was added") + seat_preference: str | None = Field( + default=None, + description='Seat preference such as "window", "aisle", or "middle"', + ) + seat_assignment: str | None = Field( + default=None, + description='Assigned seat identifier, for example "12A" or "15F"', + ) + seat_type: str | None = Field( + default=None, + description=('Type of seat selected, for example "standard", "stretch", or ' '"upfront_plus"'), + ) class Fare(BaseModel): """Frontier Airlines fare bundle. No separate cabin classes - all passengers in same cabin.""" - fare_type: FareType = "basic" # Which fare bundle (basic, economy, premium, business) - price_total: float # per passenger + fare_type: FareType = Field( + default="basic", + description="Fare bundle purchased (basic, economy, premium, or business)", + ) + price_total: float = Field(..., description="Per-passenger price of the fare") currency: str = "USD" seats_available: int - # Services included in this fare bundle (flat list, no nested references) - included_services: list[ - str - ] # e.g., ["carry_on", "standard_seat_selection", "refundability", "change_cancel_fee_waived"] - checked_bags_included: int = 0 # Number of checked bags included (0, 1, or 2 for business) + included_services: list[str] = Field( + default_factory=list, + description=("Services included in this fare bundle, e.g. carry-on or seat selection"), + ) + checked_bags_included: int = Field( + default=0, + description="Number of checked bags included (0, 1, or 2 for business)", + ) class ServiceAddOnOption(BaseModel): @@ -65,6 +77,8 @@ class ServiceAddOnOption(BaseModel): class Flight(BaseModel): + """Inventory and day-of-travel information for a specific flight.""" + id: str origin: str destination: str @@ -73,21 +87,24 @@ class Flight(BaseModel): flight_number: str carrier: str = "F9" fares: list[Fare] - add_ons: list[ServiceAddOnOption] = [] # Available add-ons for this flight - - # Day-of travel information (enriched closer to departure) - departure_terminal: str | None = None # e.g., "Terminal 1", "Terminal A" - departure_gate: str | None = None # e.g., "A15", "B22" - arrival_terminal: str | None = None # e.g., "Terminal 3" - arrival_gate: str | None = None # e.g., "C8" - - # Flight status tracking + add_ons: list[ServiceAddOnOption] = Field( + default_factory=list, + description="Available add-ons that can be purchased for the flight", + ) + departure_terminal: str | None = Field( + default=None, description='Departure terminal, e.g. "Terminal 1" or "Terminal A"' + ) + departure_gate: str | None = Field(default=None, description='Departure gate, e.g. "A15" or "B22"') + arrival_terminal: str | None = Field(default=None, description='Arrival terminal, e.g. "Terminal 3"') + arrival_gate: str | None = Field(default=None, description='Arrival gate, e.g. "C8"') status: FlightStatus = "scheduled" status_updated_at: datetime | None = None - delay_minutes: int | None = None # If delayed, minutes of delay + delay_minutes: int | None = Field(default=None, description="Minutes of delay when the flight is delayed") class BookingStatus(BaseModel): + """State and timestamps for a booking.""" + status: Literal["confirmed", "cancelled", "pending"] created_at: datetime updated_at: datetime @@ -97,26 +114,27 @@ class FlightBooking(BaseModel): """Represents a single flight within a booking.""" flight_id: str - fare_type: FareType = "basic" # Which fare bundle was purchased - - # Base fare pricing - base_price: float # Price of the fare itself + fare_type: FareType = Field(default="basic", description="Fare bundle purchased for the flight") + base_price: float = Field(..., description="Price of the fare itself") currency: str = "USD" - - # Services included in the fare (flat list, no nested references) - included_services: list[str] = [] # e.g., ["carry_on", "standard_seat_selection", "refundability"] - checked_bags_included: int = 0 # Number of checked bags included (0, 1, or 2 for business) - - # Add-on services purchased separately - add_ons: list[ServiceAddOn] = [] - - # Check-in information - checked_in: bool = False - checked_in_at: datetime | None = None # When check-in was completed - - # Final seat assignment (may differ from seat_selection preference/addon) - # This is the actual assigned seat after check-in - seat_assignment: str | None = None # e.g., "12A", "15F" + included_services: list[str] = Field( + default_factory=list, + description=("Services bundled with the fare, such as carry-on or standard seat selection"), + ) + checked_bags_included: int = Field( + default=0, + description="Number of checked bags included (0, 1, or 2 for business)", + ) + add_ons: list[ServiceAddOn] = Field( + default_factory=list, + description="Add-on services purchased separately from the base fare", + ) + checked_in: bool = Field(default=False, description="Whether the passenger has completed check-in") + checked_in_at: datetime | None = Field(default=None, description="Timestamp when check-in was completed") + seat_assignment: str | None = Field( + default=None, + description='Final seat assignment, e.g. "12A" or "15F"', + ) @computed_field # type: ignore[prop-decorator] @property From e9f4e8dce757b6b5f8ab0fc85ad252117eac2623 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Tue, 4 Nov 2025 14:46:58 -0800 Subject: [PATCH 08/26] Update system prompt to match tools --- src/airline_agent/constants.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index 984862f..5fdcae4 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -18,10 +18,16 @@ - search — find candidate articles by query (keep top-k small, ≤5), returns title/snippet/path. - get_article — get the full article by its path. - list_directory — list directory structure to make more informed searches. -- search_flights — search available flights by origin airport code, destination airport code, and departure date (YYYY-MM-DD format). Always ask for the departure date if the user doesn't provide it. Common city names like "NYC" are automatically mapped to airport codes. +- get_current_date — return the current date (YYYY-MM-DD) and timestamp. +- search_flights — search available flights by origin and destination airport codes (IATA) and departure date (YYYY-MM-DD). Always ask for the departure date if the user doesn't provide it. +- get_fare_details — retrieve fare bundle pricing, included services, and add-ons for a specific flight. - book_flights — book one or more flights for the current user. Requires list of flight IDs and fare bundle type (basic, economy, premium, business; defaults to basic). Returns booking confirmation with booking ID and total price. - get_booking — retrieve booking details by booking ID. - get_my_bookings — retrieve all confirmed bookings for the current user. +- add_service_to_booking — add an eligible service (bags, seat selection, etc.) to a specific flight within a booking. +- check_in — complete check-in for a specific flight in a booking. +- get_flight_timings — get check-in, boarding, and door-close timing windows for a flight. +- get_flight_status — get the latest status, gates, and delay information for a flight. ## Tool Use Guidelines: - Keep it tight: aim for 1-2 calls per turn (hard cap 4). From 6f780c2d2b246d162232b2798cd2529f07830b00 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Tue, 4 Nov 2025 14:53:43 -0800 Subject: [PATCH 09/26] Remove uses of cast --- src/airline_agent/tools/booking.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 48769e9..47e2519 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -5,7 +5,7 @@ import uuid from datetime import UTC, date, datetime, timedelta from pathlib import Path -from typing import Any, cast +from typing import Any from airline_agent.types.booking import ( Booking, @@ -139,7 +139,7 @@ def get_fare_details(self, flight_id: str, fare_type: str = "basic") -> dict[str def book_flights( self, flight_ids: list[str], - fare_type: str = "basic", + fare_type: FareType = "basic", ) -> Booking: """ Book one or more flights for the current user. @@ -182,7 +182,7 @@ def book_flights( flight_bookings.append( FlightBooking( flight_id=flight_id, - fare_type=cast(FareType, fare_type), + fare_type=fare_type, base_price=fare.price_total, currency=fare.currency, included_services=fare.included_services.copy(), @@ -236,7 +236,7 @@ def add_service_to_booking( self, booking_id: str, flight_id: str, - service_type: str, + service_type: ServiceType, seat_preference: str | None = None, seat_assignment: str | None = None, ) -> Booking: @@ -325,7 +325,7 @@ def add_service_to_booking( flight_booking.add_ons.append( ServiceAddOn( - service_type=cast(ServiceType, service_type), + service_type=service_type, price=addon_option.price, currency=addon_option.currency, added_at=now, From ae3e102a41cf5130eb357c5c9dd6673b91d05bc6 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 12:09:38 -0800 Subject: [PATCH 10/26] Standardize on not using PEP 563 annotations --- src/airline_agent/backend/schemas/message.py | 2 -- .../cleanlab_utils/conversion_utils.py | 13 +++---------- src/airline_agent/cleanlab_utils/validate_utils.py | 14 +++++--------- src/airline_agent/tools/booking.py | 2 -- 4 files changed, 8 insertions(+), 23 deletions(-) diff --git a/src/airline_agent/backend/schemas/message.py b/src/airline_agent/backend/schemas/message.py index 4e8ecbf..c11214d 100644 --- a/src/airline_agent/backend/schemas/message.py +++ b/src/airline_agent/backend/schemas/message.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from enum import StrEnum from typing import Literal diff --git a/src/airline_agent/cleanlab_utils/conversion_utils.py b/src/airline_agent/cleanlab_utils/conversion_utils.py index 586cfc0..d8a651f 100644 --- a/src/airline_agent/cleanlab_utils/conversion_utils.py +++ b/src/airline_agent/cleanlab_utils/conversion_utils.py @@ -1,18 +1,10 @@ """Convert pydantic-ai message history and responses to OpenAI Chat Completions format.""" -from __future__ import annotations - import base64 from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast - -from openai.types.chat import ChatCompletion - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionMessageParam - from pydantic_ai.tools import ToolDefinition +from typing import Any, Literal, cast -from openai.types.chat import ChatCompletionFunctionToolParam +from openai.types.chat import ChatCompletion, ChatCompletionFunctionToolParam, ChatCompletionMessageParam from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -32,6 +24,7 @@ UserPromptPart, VideoUrl, ) +from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage diff --git a/src/airline_agent/cleanlab_utils/validate_utils.py b/src/airline_agent/cleanlab_utils/validate_utils.py index b4cb904..0a07312 100644 --- a/src/airline_agent/cleanlab_utils/validate_utils.py +++ b/src/airline_agent/cleanlab_utils/validate_utils.py @@ -1,25 +1,20 @@ -from __future__ import annotations - import logging import re import warnings -from typing import TYPE_CHECKING, Any, cast - -if TYPE_CHECKING: - from cleanlab_codex import Project - from codex.types.project_validate_response import ProjectValidateResponse - from pydantic_ai.agent import Agent, AgentRunResult - from pydantic_ai.tools import ToolDefinition +from typing import Any, cast +from cleanlab_codex import Project from cleanlab_tlm.utils.chat import _ASSISTANT_PREFIX as ASSISTANT_PREFIX from cleanlab_tlm.utils.chat import ( _form_prompt_chat_completions_api as form_prompt_chat_completions_api, ) +from codex.types.project_validate_response import ProjectValidateResponse from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionFunctionToolParam, ChatCompletionMessageParam, ) +from pydantic_ai.agent import Agent, AgentRunResult from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -27,6 +22,7 @@ SystemPromptPart, UserPromptPart, ) +from pydantic_ai.tools import ToolDefinition from airline_agent.cleanlab_utils.conversion_utils import ( convert_message_to_chat_completion, diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 47e2519..24b95db 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import json import random import uuid From b110ef33ded9717e22daccedcb112522e4257105 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 13:00:52 -0800 Subject: [PATCH 11/26] Switch to in-memory reservations This makes the agent stateless across backend runs. --- .../backend/services/airline_chat.py | 1 - src/airline_agent/tools/booking.py | 42 +++---------------- 2 files changed, 5 insertions(+), 38 deletions(-) diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 8a14567..331f044 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -95,7 +95,6 @@ def get_cleanlab_project() -> Project: ) booking = BookingTools( flights_path=str(pathlib.Path(__file__).parents[4] / "data/flights.json"), - reservations_path=str(pathlib.Path(__file__).parents[4] / "data/reservations.json"), ) project = get_cleanlab_project() agent = create_agent(kb, booking) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 24b95db..4b246cf 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -2,7 +2,6 @@ import random import uuid from datetime import UTC, date, datetime, timedelta -from pathlib import Path from typing import Any from airline_agent.types.booking import ( @@ -22,42 +21,14 @@ class BookingTools: - def __init__(self, flights_path: str, reservations_path: str | None = None): + def __init__(self, flights_path: str): self._flights_path = flights_path with open(flights_path) as f: raw = json.load(f) self._flights: dict[str, Flight] = {x["id"]: Flight(**x) for x in raw["flights"]} - # Initialize reservations storage - self._reservations_path = reservations_path + # Initialize reservations storage (in-memory only) self._reservations: dict[str, Booking] = {} - if reservations_path: - self._load_reservations() - - def _load_reservations(self) -> None: - """Load reservations from JSON file.""" - if not self._reservations_path: - return - try: - reservations_file = Path(self._reservations_path) - if reservations_file.exists(): - with open(reservations_file) as f: - data = json.load(f) - self._reservations = {bid: Booking(**booking_data) for bid, booking_data in data.items()} - else: - # Create empty file if it doesn't exist - reservations_file.parent.mkdir(parents=True, exist_ok=True) - self._save_reservations() - except (FileNotFoundError, json.JSONDecodeError): - self._reservations = {} - - def _save_reservations(self) -> None: - """Save reservations to JSON file.""" - if not self._reservations_path: - return - data = {bid: booking.model_dump(mode="json") for bid, booking in self._reservations.items()} - with open(self._reservations_path, "w") as f: - json.dump(data, f, indent=2, default=str) def _save_flights(self) -> None: """Save flights to JSON file.""" @@ -112,7 +83,7 @@ def get_fare_details(self, flight_id: str, fare_type: str = "basic") -> dict[str fare = next((f for f in flight.fares if f.fare_type == fare_type), None) if not fare: available_fares = [f.fare_type for f in flight.fares] - msg = f"Fare '{fare_type}' not available for flight {flight_id}. " f"Available fares: {available_fares}" + msg = f"Fare '{fare_type}' not available for flight {flight_id}. Available fares: {available_fares}" raise ValueError(msg) return { @@ -170,7 +141,7 @@ def book_flights( fare = next((f for f in flight.fares if f.fare_type == fare_type), None) if not fare: available_fares = [f.fare_type for f in flight.fares] - msg = f"Fare '{fare_type}' not available for flight {flight_id}. " f"Available fares: {available_fares}" + msg = f"Fare '{fare_type}' not available for flight {flight_id}. Available fares: {available_fares}" raise ValueError(msg) if fare.seats_available <= 0: @@ -202,7 +173,6 @@ def book_flights( ) self._reservations[booking_id] = booking - self._save_reservations() return booking @@ -336,9 +306,8 @@ def add_service_to_booking( # Update booking timestamp booking.status.updated_at = now - # Save the updated booking + # Save the updated booking in memory self._reservations[booking_id] = booking - self._save_reservations() return booking @@ -498,7 +467,6 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: # Save changes self._reservations[booking_id] = booking - self._save_reservations() self._save_flights() return booking From 37d0690c6d27c3212d8b33755e15eb9129d8cc85 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 13:10:18 -0800 Subject: [PATCH 12/26] Factor out current time as a constant --- scripts/generate_flights.py | 13 +++++----- .../backend/services/airline_chat.py | 19 ++++++++++++--- .../cleanlab_utils/conversion_utils.py | 8 ++++--- src/airline_agent/constants.py | 15 ++++++------ src/airline_agent/tools/booking.py | 24 +++++-------------- 5 files changed, 42 insertions(+), 37 deletions(-) diff --git a/scripts/generate_flights.py b/scripts/generate_flights.py index ec9157b..8a16a07 100755 --- a/scripts/generate_flights.py +++ b/scripts/generate_flights.py @@ -2,7 +2,6 @@ """ Script to generate Frontier Airlines (F9) flight data for SF Bay Area to New York routes. Includes direct flights and connecting flights with layovers through hub airports. -Comprehensive coverage for Halloween week 2025 (Oct 31 - Nov 7). """ import json @@ -10,6 +9,8 @@ from datetime import UTC, datetime, timedelta from pathlib import Path +from airline_agent.constants import FLIGHT_DATA_DATE + # Constants SHORT_FLIGHT_THRESHOLD_HOURS = 2.0 # Threshold for short flights (hours) DURATION_ADJUSTMENT = 0.1 # Adjustment for SJC/OAK flights (hours) @@ -471,11 +472,10 @@ def main(): # Set random seed for reproducibility random.seed(42) - # Start date: Halloween 2025 (October 31) + 1 week - start_date = datetime(2025, 10, 31, tzinfo=UTC) - num_days = 8 # Oct 31 - Nov 7 + start_date = datetime.combine(FLIGHT_DATA_DATE, datetime.min.time(), tzinfo=UTC) + num_days = 8 - print("Generating comprehensive flight data for Halloween week 2025 (Oct 31 - Nov 7)...") + print(f"Generating comprehensive flight data starting from {FLIGHT_DATA_DATE.isoformat()}...") # Generate SF -> NYC flights (direct only) print("Generating direct flights (SF -> NYC)...") @@ -506,10 +506,11 @@ def main(): with open(flights_file, "w") as f: json.dump(output_data, f, indent=2) + end_date = start_date + timedelta(days=num_days - 1) print(f"\n✓ Successfully saved {len(all_flights)} total flights to {flights_file}") print(f" - Direct flights SF->NYC: {len(direct_flights_sf_to_nyc)}") print(f" - Direct flights NYC->SF: {len(direct_flights_nyc_to_sf)}") - print(" - All flights are DIRECT flights for Halloween week 2025 (Oct 31 - Nov 7)") + print(f" - All flights are DIRECT flights from {start_date.date().isoformat()} to {end_date.date().isoformat()}") print(" - No connecting/transfer flights included") diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 331f044..65f773e 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -4,6 +4,7 @@ import pathlib import uuid from collections.abc import AsyncGenerator +from textwrap import dedent from cleanlab_codex import Client, Project from codex.types import ProjectValidateResponse @@ -46,7 +47,7 @@ get_tools_in_openai_format, run_cleanlab_validation_logging_tools, ) -from airline_agent.constants import AGENT_INSTRUCTIONS, AGENT_MODEL +from airline_agent.constants import AGENT_BASE_INSTRUCTIONS, AGENT_MODEL, DEMO_DATETIME from airline_agent.tools.booking import BookingTools from airline_agent.tools.knowledge_base import KnowledgeBase @@ -56,17 +57,29 @@ logger.setLevel(logging.INFO) +def instructions() -> str: + current_date_str = DEMO_DATETIME.date().isoformat() + current_datetime_str = DEMO_DATETIME.strftime("%Y-%m-%d %H:%M:%S %Z") + return dedent(f""" + {AGENT_BASE_INSTRUCTIONS} + + ## Context: + - Today's date: {current_date_str} + - Current time: {current_datetime_str} + """).strip() + + def create_agent(kb: KnowledgeBase, booking: BookingTools) -> Agent: """Create the airline support agent.""" model = OpenAIChatModel(model_name=AGENT_MODEL, settings=ModelSettings(temperature=0.0)) + return Agent( model=model, - instructions=AGENT_INSTRUCTIONS, + instructions=instructions, tools=[ kb.get_article, kb.search, kb.list_directory, - booking.get_current_date, booking.search_flights, booking.get_fare_details, booking.book_flights, diff --git a/src/airline_agent/cleanlab_utils/conversion_utils.py b/src/airline_agent/cleanlab_utils/conversion_utils.py index d8a651f..9ebe9a6 100644 --- a/src/airline_agent/cleanlab_utils/conversion_utils.py +++ b/src/airline_agent/cleanlab_utils/conversion_utils.py @@ -1,7 +1,7 @@ """Convert pydantic-ai message history and responses to OpenAI Chat Completions format.""" import base64 -from datetime import UTC, datetime +from datetime import datetime from typing import Any, Literal, cast from openai.types.chat import ChatCompletion, ChatCompletionFunctionToolParam, ChatCompletionMessageParam @@ -27,6 +27,8 @@ from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import RequestUsage +from airline_agent.constants import DEMO_DATETIME + def convert_to_openai_messages(message_history: list[ModelMessage]) -> list[ChatCompletionMessageParam]: """Convert pydantic-ai message history to OpenAI Chat Completions format.""" @@ -234,7 +236,7 @@ def convert_message_to_chat_completion(message: ChatCompletionMessageParam) -> C "message": choice_message, } ], - "created": int(datetime.now(UTC).timestamp()), + "created": int(DEMO_DATETIME.timestamp()), "model": "mock-agent", "object": "chat.completion", "service_tier": "default", @@ -277,7 +279,7 @@ def convert_string_to_response_message( model_name = None if timestamp is None: - timestamp = datetime.now(UTC) + timestamp = DEMO_DATETIME text_part = TextPart(content=content) usage = RequestUsage(input_tokens=0, output_tokens=0) return ModelResponse( diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index 5fdcae4..ead903b 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -1,7 +1,13 @@ import logging +from datetime import date, datetime, timedelta +from zoneinfo import ZoneInfo logger = logging.getLogger(__name__) +DEMO_DATE = date(2025, 11, 5) +DEMO_DATETIME = datetime(2025, 11, 5, 14, 0, 0, tzinfo=ZoneInfo("America/Los_Angeles")) +FLIGHT_DATA_DATE = DEMO_DATE + timedelta(days=7) + OFFICIAL_DEMO_PROJECT_ID = "3aae1f96-2dda-492f-8c86-17d453d3c298" # to copy configuration from STAGING_DEMO_PROJECT_ID = "6de236e4-c6e7-456c-b248-872236010992" RAG_EMBED_MODEL = "text-embedding-3-small" @@ -10,15 +16,13 @@ RAG_CHUNK_OVERLAP = 200 CONTEXT_RETRIEVAL_TOOLS = ["search", "get_article", "list_directory"] AGENT_MODEL = "gpt-4o" -AGENT_INSTRUCTIONS = ( - """You are an AI customer support agent for Frontier Airlines. You can use tools to access to a knowledge base of articles and -documents about the airline's services, policies, and procedures. +AGENT_BASE_INSTRUCTIONS = """ +You are an AI customer support agent for Frontier Airlines. You can use tools to access to a knowledge base of articles and documents about the airline's services, policies, and procedures. ## You have access to the following tools: - search — find candidate articles by query (keep top-k small, ≤5), returns title/snippet/path. - get_article — get the full article by its path. - list_directory — list directory structure to make more informed searches. -- get_current_date — return the current date (YYYY-MM-DD) and timestamp. - search_flights — search available flights by origin and destination airport codes (IATA) and departure date (YYYY-MM-DD). Always ask for the departure date if the user doesn't provide it. - get_fare_details — retrieve fare bundle pricing, included services, and add-ons for a specific flight. - book_flights — book one or more flights for the current user. Requires list of flight IDs and fare bundle type (basic, economy, premium, business; defaults to basic). Returns booking confirmation with booking ID and total price. @@ -44,8 +48,5 @@ - If you book flights, provide the booking ID and summarize the flights booked and total price. - If the user asks about anything unrelated to the airline, politely inform them that you can only assist with airline-related inquiries. """.strip() - .replace("\n", " ") - .replace(" ", " ") -) FALLBACK_RESPONSE = "I'm sorry, but I don't have the information you're looking for. Please rephrase the question or contact Frontier Airlines customer support for further assistance." diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 4b246cf..adc8bec 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -1,9 +1,10 @@ import json import random import uuid -from datetime import UTC, date, datetime, timedelta +from datetime import date, datetime, timedelta from typing import Any +from airline_agent.constants import DEMO_DATETIME from airline_agent.types.booking import ( Booking, BookingStatus, @@ -124,7 +125,7 @@ def book_flights( msg = "At least one flight ID must be provided" raise ValueError(msg) - now = datetime.now(UTC) + now = DEMO_DATETIME booking_id = f"BK-{uuid.uuid4().hex[:8].upper()}" flight_bookings: list[FlightBooking] = [] @@ -274,7 +275,7 @@ def add_service_to_booking( raise ValueError(msg) # Add the service add-on - now = datetime.now(UTC) + now = DEMO_DATETIME # For seat selection services, validate that preferences/assignments are only set for seat selection seat_selection_types = ["standard_seat_selection", "premium_seat_selection", "upfront_plus_seating"] @@ -311,19 +312,6 @@ def add_service_to_booking( return booking - def get_current_date(self) -> dict[str, str]: - """ - Get the current date and time. - - Returns: - Dictionary with current date in YYYY-MM-DD format and ISO timestamp - """ - now = datetime.now(UTC) - return { - "date": now.date().isoformat(), # YYYY-MM-DD format - "datetime": now.isoformat(), # Full ISO timestamp with timezone - } - def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str: """Assign a seat to a flight booking based on preferences, fare type, or randomly.""" # Check if any seat selection add-on exists with an assignment @@ -449,7 +437,7 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: raise ValueError(msg) flight = self._flights[flight_id] - now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now(UTC) + now = DEMO_DATETIME # Assign gates and terminals if needed self._assign_gates_and_terminals(flight) @@ -531,7 +519,7 @@ def get_flight_status(self, flight_id: str) -> dict[str, Any]: self._assign_gates_and_terminals(flight) # Update flight status based on current time - now = datetime.now(flight.departure.tzinfo) if flight.departure.tzinfo else datetime.now(UTC) + now = DEMO_DATETIME time_until_departure = flight.departure - now # Update status based on current time vs scheduled departure From c5b5f8714bd16f762d82b6b5f94d80f31fc30f2d Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 13:39:55 -0800 Subject: [PATCH 13/26] Switch to seeded random.Random instance --- scripts/generate_flights.py | 80 +++++++++++++++--------------- src/airline_agent/tools/booking.py | 30 ++++++----- 2 files changed, 59 insertions(+), 51 deletions(-) diff --git a/scripts/generate_flights.py b/scripts/generate_flights.py index 8a16a07..e39e561 100755 --- a/scripts/generate_flights.py +++ b/scripts/generate_flights.py @@ -181,44 +181,44 @@ def get_timezone_offset(airport: str) -> int: return HUB_TZ_OFFSETS.get(airport, -5) -def generate_fares() -> list[dict]: +def generate_fares(rng: random.Random) -> list[dict]: """Generate random fares for a flight with different fare bundles (Frontier Airlines model).""" fares = [] # Basic fare: no services included - basic_price = random.uniform(*FARE_BASE_PRICES["basic"]["price_range"]) # noqa: S311 + basic_price = rng.uniform(*FARE_BASE_PRICES["basic"]["price_range"]) fares.append( { "fare_type": "basic", "price_total": round(basic_price, 2), "currency": "USD", - "seats_available": random.randint(5, 15), # noqa: S311 + "seats_available": rng.randint(5, 15), "included_services": [], "checked_bags_included": 0, } ) # Economy bundle: Basic + Carry on, Standard seat selection, Refundability, Change/cancel fee waived - economy_price = random.uniform(*FARE_BASE_PRICES["economy"]["price_range"]) # noqa: S311 + economy_price = rng.uniform(*FARE_BASE_PRICES["economy"]["price_range"]) fares.append( { "fare_type": "economy", "price_total": round(economy_price, 2), "currency": "USD", - "seats_available": random.randint(3, 12), # noqa: S311 + "seats_available": rng.randint(3, 12), "included_services": ["carry_on", "standard_seat_selection", "refundability", "change_cancel_fee_waived"], "checked_bags_included": 0, } ) # Premium bundle: Economy + Premium seat selection + Priority Boarding - premium_price = random.uniform(*FARE_BASE_PRICES["premium"]["price_range"]) # noqa: S311 + premium_price = rng.uniform(*FARE_BASE_PRICES["premium"]["price_range"]) fares.append( { "fare_type": "premium", "price_total": round(premium_price, 2), "currency": "USD", - "seats_available": random.randint(2, 8), # noqa: S311 + "seats_available": rng.randint(2, 8), "included_services": [ "carry_on", "standard_seat_selection", @@ -232,13 +232,13 @@ def generate_fares() -> list[dict]: ) # Business bundle: Premium + 2 checked bags + UpFront Plus Seating - business_price = random.uniform(*FARE_BASE_PRICES["business"]["price_range"]) # noqa: S311 + business_price = rng.uniform(*FARE_BASE_PRICES["business"]["price_range"]) fares.append( { "fare_type": "business", "price_total": round(business_price, 2), "currency": "USD", - "seats_available": random.randint(1, 4), # noqa: S311 + "seats_available": rng.randint(1, 4), "included_services": [ "carry_on", "standard_seat_selection", @@ -255,60 +255,60 @@ def generate_fares() -> list[dict]: return fares -def generate_add_ons() -> list[dict]: +def generate_add_ons(rng: random.Random) -> list[dict]: """Generate available add-on services for a flight.""" return [ { "service_type": "checked_bag", - "price": round(random.uniform(30, 40), 2), # noqa: S311 + "price": round(rng.uniform(30, 40), 2), "currency": "USD", "description": "One checked bag (up to 50 lbs, 62 linear inches)", }, { "service_type": "carry_on", - "price": round(random.uniform(20, 30), 2), # noqa: S311 + "price": round(rng.uniform(20, 30), 2), "currency": "USD", "description": "One carry-on bag (personal item included)", }, { "service_type": "standard_seat_selection", - "price": round(random.uniform(10, 25), 2), # noqa: S311 + "price": round(rng.uniform(10, 25), 2), "currency": "USD", "description": "Select a standard seat in advance", }, { "service_type": "premium_seat_selection", - "price": round(random.uniform(25, 45), 2), # noqa: S311 + "price": round(rng.uniform(25, 45), 2), "currency": "USD", "description": "Select a stretch seat with extra legroom", }, { "service_type": "upfront_plus_seating", - "price": round(random.uniform(50, 100), 2), # noqa: S311 + "price": round(rng.uniform(50, 100), 2), "currency": "USD", "description": "UpFront Plus seating in first two rows with guaranteed empty middle seat", }, { "service_type": "priority_boarding", - "price": round(random.uniform(8, 15), 2), # noqa: S311 + "price": round(rng.uniform(8, 15), 2), "currency": "USD", "description": "Priority boarding with overhead bin space", }, { "service_type": "travel_insurance", - "price": round(random.uniform(15, 30), 2), # noqa: S311 + "price": round(rng.uniform(15, 30), 2), "currency": "USD", "description": "Trip protection insurance", }, { "service_type": "refundability", - "price": round(random.uniform(30, 60), 2), # noqa: S311 + "price": round(rng.uniform(30, 60), 2), "currency": "USD", "description": "Add refundability to your booking", }, { "service_type": "change_cancel_fee_waived", - "price": round(random.uniform(20, 40), 2), # noqa: S311 + "price": round(rng.uniform(20, 40), 2), "currency": "USD", "description": "Waive change and cancel fees", }, @@ -322,6 +322,7 @@ def generate_flight_id(origin: str, destination: str, departure: datetime, carri def generate_direct_flights( + rng: random.Random, start_date: datetime, num_days: int = 8, origin_airports: list[str] | None = None, @@ -342,12 +343,12 @@ def generate_direct_flights( for origin in origin_airports: for destination in dest_airports: # Generate 3-6 flights per origin-destination pair per day - num_flights = random.randint(3, 6) # noqa: S311 + num_flights = rng.randint(3, 6) for _ in range(num_flights): # Random departure time between 6 AM and 10 PM - hour = random.randint(6, 22) # noqa: S311 - minute = random.choice([0, 15, 30, 45]) # noqa: S311 + hour = rng.randint(6, 22) + minute = rng.choice([0, 15, 30, 45]) carrier_code = CARRIER_CODE @@ -370,10 +371,10 @@ def generate_direct_flights( "destination": destination, "departure": departure_str, "arrival": arrival_str, - "flight_number": f"{carrier_code} {random.randint(100, 999)}", # noqa: S311 + "flight_number": f"{carrier_code} {rng.randint(100, 999)}", "carrier": carrier_code, - "fares": generate_fares(), - "add_ons": generate_add_ons(), + "fares": generate_fares(rng), + "add_ons": generate_add_ons(rng), } flights.append(flight) @@ -382,6 +383,7 @@ def generate_direct_flights( def generate_connecting_flights( + rng: random.Random, start_date: datetime, num_days: int = 8, origin_airports: list[str] | None = None, @@ -405,21 +407,21 @@ def generate_connecting_flights( # Use all hubs to create many transfer options for hub in HUB_AIRPORTS: # Generate 1-3 connecting flights per hub per origin-destination pair per day - num_routes = random.randint(1, 3) # noqa: S311 + num_routes = rng.randint(1, 3) for _ in range(num_routes): carrier_code = CARRIER_CODE # First leg: Origin -> Hub - hour1 = random.randint(6, 18) # noqa: S311 - minute1 = random.choice([0, 15, 30, 45]) # noqa: S311 + hour1 = rng.randint(6, 18) + minute1 = rng.choice([0, 15, 30, 45]) departure_time_leg1 = date.replace(hour=hour1, minute=minute1, second=0, microsecond=0) duration1 = get_flight_duration(origin, hub) arrival_time_leg1 = departure_time_leg1 + timedelta(hours=duration1) # Layover: 45 minutes to 3 hours - layover_hours = random.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) # noqa: S311 + layover_hours = rng.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) departure_time_leg2 = arrival_time_leg1 + timedelta(hours=layover_hours) # Second leg: Hub -> Destination @@ -438,10 +440,10 @@ def generate_connecting_flights( f"%Y-%m-%dT%H:%M:00{departure_offset_leg1:+03d}:00" ), "arrival": arrival_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg1:+03d}:00"), - "flight_number": f"{carrier_code} {random.randint(100, 999)}", # noqa: S311 + "flight_number": f"{carrier_code} {rng.randint(100, 999)}", "carrier": carrier_code, - "fares": generate_fares(), - "add_ons": generate_add_ons(), + "fares": generate_fares(rng), + "add_ons": generate_add_ons(rng), } # Second leg @@ -456,10 +458,10 @@ def generate_connecting_flights( f"%Y-%m-%dT%H:%M:00{departure_offset_leg2:+03d}:00" ), "arrival": arrival_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg2:+03d}:00"), - "flight_number": f"{carrier_code} {random.randint(100, 999)}", # noqa: S311 + "flight_number": f"{carrier_code} {rng.randint(100, 999)}", "carrier": carrier_code, - "fares": generate_fares(), - "add_ons": generate_add_ons(), + "fares": generate_fares(rng), + "add_ons": generate_add_ons(rng), } flights.extend([flight1, flight2]) @@ -469,8 +471,8 @@ def generate_connecting_flights( def main(): """Main function to generate and save flight data.""" - # Set random seed for reproducibility - random.seed(42) + # Create seeded random number generator for reproducibility + rng = random.Random(42) # noqa: S311 start_date = datetime.combine(FLIGHT_DATA_DATE, datetime.min.time(), tzinfo=UTC) num_days = 8 @@ -480,14 +482,14 @@ def main(): # Generate SF -> NYC flights (direct only) print("Generating direct flights (SF -> NYC)...") direct_flights_sf_to_nyc = generate_direct_flights( - start_date, num_days=num_days, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS + rng, start_date, num_days=num_days, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS ) print(f"Generated {len(direct_flights_sf_to_nyc)} direct flights from SF to NYC") # Generate NYC -> SF flights (direct only) print("Generating direct flights (NYC -> SF)...") direct_flights_nyc_to_sf = generate_direct_flights( - start_date, num_days=num_days, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS + rng, start_date, num_days=num_days, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS ) print(f"Generated {len(direct_flights_nyc_to_sf)} direct flights from NYC to SF") diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index adc8bec..3095cbf 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -1,6 +1,5 @@ import json import random -import uuid from datetime import date, datetime, timedelta from typing import Any @@ -20,6 +19,9 @@ BOARDING_START_THRESHOLD = 900 # 15 minutes until departure ON_TIME_THRESHOLD = 1800 # 30 minutes until departure +# Seed for deterministic random number generation +RNG_SEED = 42 + class BookingTools: def __init__(self, flights_path: str): @@ -31,6 +33,9 @@ def __init__(self, flights_path: str): # Initialize reservations storage (in-memory only) self._reservations: dict[str, Booking] = {} + # Initialize seeded random number generator for deterministic behavior + self._rng = random.Random(RNG_SEED) # noqa: S311 + def _save_flights(self) -> None: """Save flights to JSON file.""" data = {"flights": [flight.model_dump(mode="json") for flight in self._flights.values()]} @@ -126,7 +131,8 @@ def book_flights( raise ValueError(msg) now = DEMO_DATETIME - booking_id = f"BK-{uuid.uuid4().hex[:8].upper()}" + # Generate deterministic booking ID using seeded random + booking_id = f"BK-{self._rng.randint(0, 0xFFFFFFFF):08X}" flight_bookings: list[FlightBooking] = [] currency = "USD" @@ -345,7 +351,7 @@ def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str: # Default to standard seating area row_min, row_max = (16, 40) - row = random.randint(row_min, row_max) # noqa: S311 + row = self._rng.randint(row_min, row_max) # Seat letters (typical 3-3 configuration: A, B, C, D, E, F) seat_letters = ["A", "B", "C", "D", "E", "F"] @@ -353,11 +359,11 @@ def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str: aisle_seats = ["C", "D"] if preference == "window": - seat_letter = random.choice(window_seats) # noqa: S311 + seat_letter = self._rng.choice(window_seats) elif preference == "aisle": - seat_letter = random.choice(aisle_seats) # noqa: S311 + seat_letter = self._rng.choice(aisle_seats) else: - seat_letter = random.choice(seat_letters) # noqa: S311 + seat_letter = self._rng.choice(seat_letters) return f"{row}{seat_letter}" @@ -366,24 +372,24 @@ def _assign_gates_and_terminals(self, flight: Flight) -> None: # Assign departure terminal and gate if not already assigned if not flight.departure_terminal: terminals = ["Terminal 1", "Terminal 2", "Terminal 3", "Terminal A", "Terminal B"] - flight.departure_terminal = random.choice(terminals) # noqa: S311 + flight.departure_terminal = self._rng.choice(terminals) if not flight.departure_gate: # Generate a gate like "A15", "B22", "C8" gate_letters = ["A", "B", "C", "D"] - gate_letter = random.choice(gate_letters) # noqa: S311 - gate_number = random.randint(1, 50) # noqa: S311 + gate_letter = self._rng.choice(gate_letters) + gate_number = self._rng.randint(1, 50) flight.departure_gate = f"{gate_letter}{gate_number}" # Assign arrival terminal and gate if not already assigned if not flight.arrival_terminal: terminals = ["Terminal 1", "Terminal 2", "Terminal 3", "Terminal A", "Terminal B"] - flight.arrival_terminal = random.choice(terminals) # noqa: S311 + flight.arrival_terminal = self._rng.choice(terminals) if not flight.arrival_gate: gate_letters = ["A", "B", "C", "D", "E"] - gate_letter = random.choice(gate_letters) # noqa: S311 - gate_number = random.randint(1, 60) # noqa: S311 + gate_letter = self._rng.choice(gate_letters) + gate_number = self._rng.randint(1, 60) flight.arrival_gate = f"{gate_letter}{gate_number}" def _calculate_check_in_timings(self, departure: datetime) -> dict[str, datetime]: From 9062fb0dcaf04a946f726bc88b69d7336f7f0ddb Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 13:50:24 -0800 Subject: [PATCH 14/26] Make booking tool not mutate on-disk data --- src/airline_agent/tools/booking.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 3095cbf..9a86f8d 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -36,12 +36,6 @@ def __init__(self, flights_path: str): # Initialize seeded random number generator for deterministic behavior self._rng = random.Random(RNG_SEED) # noqa: S311 - def _save_flights(self) -> None: - """Save flights to JSON file.""" - data = {"flights": [flight.model_dump(mode="json") for flight in self._flights.values()]} - with open(self._flights_path, "w") as f: - json.dump(data, f, indent=2, default=str) - def search_flights(self, origin: str, destination: str, departure_date: str) -> list[Flight]: """ Search available flights by route and date. @@ -461,7 +455,6 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: # Save changes self._reservations[booking_id] = booking - self._save_flights() return booking From 3d307e2cfc6f5ab620d9f8fe60529ce1b289c044 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 14:15:12 -0800 Subject: [PATCH 15/26] Use model_validate --- src/airline_agent/preprocessing/create_vector_database.py | 2 +- src/airline_agent/tools/booking.py | 2 +- src/airline_agent/tools/knowledge_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/airline_agent/preprocessing/create_vector_database.py b/src/airline_agent/preprocessing/create_vector_database.py index 2e25a10..2b16478 100644 --- a/src/airline_agent/preprocessing/create_vector_database.py +++ b/src/airline_agent/preprocessing/create_vector_database.py @@ -25,7 +25,7 @@ def main() -> None: verify_checksum(data_path) with data_path.open() as f: - data: list[KBArticle] = [KBArticle(**entry) for entry in json.load(f)] + data: list[KBArticle] = [KBArticle.model_validate(entry) for entry in json.load(f)] documents = to_documents(data) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 9a86f8d..5f781d0 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -28,7 +28,7 @@ def __init__(self, flights_path: str): self._flights_path = flights_path with open(flights_path) as f: raw = json.load(f) - self._flights: dict[str, Flight] = {x["id"]: Flight(**x) for x in raw["flights"]} + self._flights: dict[str, Flight] = {x["id"]: Flight.model_validate(x) for x in raw["flights"]} # Initialize reservations storage (in-memory only) self._reservations: dict[str, Booking] = {} diff --git a/src/airline_agent/tools/knowledge_base.py b/src/airline_agent/tools/knowledge_base.py index 7123179..e1988cb 100644 --- a/src/airline_agent/tools/knowledge_base.py +++ b/src/airline_agent/tools/knowledge_base.py @@ -14,7 +14,7 @@ class KnowledgeBase: def __init__(self, kb_path: str, vector_index_path: str): with open(kb_path) as f: - kb_entries: list[KBArticle] = [KBArticle(**article) for article in json.load(f)] + kb_entries: list[KBArticle] = [KBArticle.model_validate(article) for article in json.load(f)] self._kb: dict[str, KBArticle] = {article.path: article for article in kb_entries} storage_context = StorageContext.from_defaults(persist_dir=vector_index_path) From 9ef5270b3684c9658cb9ac36c67a056138f26730 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 14:21:02 -0800 Subject: [PATCH 16/26] Add tests --- src/airline_agent/tools/booking.py | 5 + tests/conftest.py | 7 + tests/test_booking.py | 202 +++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 tests/test_booking.py diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 5f781d0..b0bca38 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -36,6 +36,11 @@ def __init__(self, flights_path: str): # Initialize seeded random number generator for deterministic behavior self._rng = random.Random(RNG_SEED) # noqa: S311 + def _reset(self) -> None: + """Clear all reservations and reset the random number generator for test isolation.""" + self._reservations = {} + self._rng = random.Random(RNG_SEED) # noqa: S311 + def search_flights(self, origin: str, destination: str, departure_date: str) -> list[Flight]: """ Search available flights by route and date. diff --git a/tests/conftest.py b/tests/conftest.py index a1520c6..cf34546 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from cleanlab_codex import Client from dotenv import load_dotenv +from airline_agent.backend.services.airline_chat import booking from airline_agent.constants import OFFICIAL_DEMO_PROJECT_ID from tests.util import Project @@ -15,6 +16,12 @@ def load_env() -> None: load_dotenv() +@pytest.fixture(autouse=True) +def clear_bookings() -> None: + """Clear all bookings before each test to ensure test isolation.""" + booking._reset() # noqa: SLF001 + + @pytest.fixture(scope="session") def codex_client(load_env: None) -> Client: # noqa: ARG001 codex_api_key = os.getenv("CODEX_API_KEY") diff --git a/tests/test_booking.py b/tests/test_booking.py new file mode 100644 index 0000000..29113cf --- /dev/null +++ b/tests/test_booking.py @@ -0,0 +1,202 @@ +from tests.judge import assert_judge +from tests.util import Agent + + +def test_search_flights() -> None: + """Test searching for flights between SF and NYC.""" + agent = Agent(cleanlab_enabled=False) + answer, _ = agent.chat("I want to fly from SFO to JFK on November 12, 2025") + assert_judge( + [ + "output mentions or lists available flights from SFO to JFK", + "output mentions departure times or flight options", + ], + answer, + ) + + +def test_get_fare_details() -> None: + """Test getting fare bundle details for a flight.""" + agent = Agent(cleanlab_enabled=False) + # First search for flights + agent.chat("Show me flights from SFO to EWR on November 12, 2025") + # Then ask about fare details + answer, _ = agent.chat("What's included in the economy fare bundle for the first flight?") + assert_judge( + [ + "output mentions details about the economy fare bundle", + "output mentions included services or benefits", + ], + answer, + ) + + +def test_book_single_flight() -> None: + """Test booking a single flight.""" + agent = Agent(cleanlab_enabled=False) + # Search and book + agent.chat("I need a flight from SFO to JFK on November 12, 2025") + answer, _ = agent.chat("Book the first available flight with basic fare") + assert_judge( + [ + "output confirms a booking was made", + "output mentions a booking ID or confirmation number", + "output mentions the total price", + ], + answer, + ) + + +def test_book_round_trip() -> None: + """Test booking a round trip (outbound and return flights).""" + agent = Agent(cleanlab_enabled=False) + # Search for outbound + agent.chat("Find flights from OAK to LGA on November 13, 2025") + # Search for return + agent.chat("Find return flights from LGA to OAK on November 15, 2025") + # Book both + answer, _ = agent.chat("Book the first flight for each leg with economy fare") + assert_judge( + [ + "output confirms booking of both flights (outbound and return)", + "output mentions a booking ID", + "output mentions the total price for both flights", + ], + answer, + ) + + +def test_retrieve_booking() -> None: + """Test retrieving booking details by booking ID.""" + agent = Agent(cleanlab_enabled=False) + # Create a booking first + agent.chat("Find a flight from SJC to JFK on November 12, 2025") + booking_response, _ = agent.chat("Book the first flight with basic fare") + # Retrieve bookings + answer, _ = agent.chat("Show me my bookings") + assert_judge( + [ + "output shows booking information", + "output mentions flight details or booking status", + ], + answer, + ) + + +def test_add_service_to_booking() -> None: + """Test adding a service (checked bag) to an existing booking.""" + agent = Agent(cleanlab_enabled=False) + # Create a booking + agent.chat("Show me flights from SFO to EWR on November 14, 2025") + agent.chat("Book the first flight with basic fare") + # Add a service + answer, _ = agent.chat("Add a checked bag to my booking") + assert_judge( + [ + "output confirms the checked bag was added", + "output mentions the additional cost or updated price", + ], + answer, + ) + + +def test_check_in() -> None: + """Test checking in for a flight.""" + agent = Agent(cleanlab_enabled=False) + # Create a booking + agent.chat("Find flights from SFO to JFK on November 12, 2025") + agent.chat("Book the first available flight") + # Check in + answer, _ = agent.chat("Check me in for my flight") + assert_judge( + [ + "output confirms check-in was successful", + "output mentions a seat assignment or boarding information", + ], + answer, + ) + + +def test_flight_status() -> None: + """Test getting flight status information.""" + agent = Agent(cleanlab_enabled=False) + # Search for a flight first to get context + agent.chat("Show me flights from OAK to LGA on November 12, 2025") + # Ask for status + answer, _ = agent.chat("What's the status of the first flight?") + assert_judge( + [ + "output provides flight status information", + "output mentions status (on time, delayed, boarding, etc.) or gate information", + ], + answer, + ) + + +def test_flight_timings() -> None: + """Test getting flight timing windows (check-in, boarding, etc.).""" + agent = Agent(cleanlab_enabled=False) + # Search for a flight + agent.chat("Find flights from SJC to EWR on November 13, 2025") + # Ask about timings + answer, _ = agent.chat("When does check-in open for the first flight?") + assert_judge( + [ + "output provides timing information", + "output mentions check-in times or boarding times", + ], + answer, + ) + + +def test_fare_comparison() -> None: + """Test comparing different fare bundles.""" + agent = Agent(cleanlab_enabled=False) + agent.chat("Show me flights from SFO to JFK on November 12, 2025") + answer, _ = agent.chat("What's the difference between basic and premium fare for the first flight?") + assert_judge( + [ + "output compares basic and premium fares", + "output mentions differences in price or included services", + ], + answer, + ) + + +def test_invalid_route() -> None: + """Test handling of invalid or unavailable routes.""" + agent = Agent(cleanlab_enabled=False) + answer, _ = agent.chat("Find flights from SFO to Tokyo on November 12, 2025") + assert_judge( + [ + "output indicates no flights are available or the route is not served", + "output does NOT show flights from SFO to Tokyo", + ], + answer, + ) + + +def test_no_date_provided() -> None: + """Test that agent asks for date when searching flights without one.""" + agent = Agent(cleanlab_enabled=False) + answer, _ = agent.chat("I want to fly from SFO to JFK") + assert_judge( + [ + "output asks for the departure date", + "output does NOT show a list of flights (because date is missing)", + ], + answer, + ) + + +def test_no_existing_bookings() -> None: + """Test isolation: verify no bookings exist from previous tests.""" + agent = Agent(cleanlab_enabled=False) + answer, _ = agent.chat("Show me my bookings") + assert_judge( + [ + "output indicates there are no bookings or no confirmed bookings", + "output does NOT list any specific flight bookings", + ], + answer, + ) From f8533e2289fd28a3f2ea8c89f8aeef4d1dce9017 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 15:00:34 -0800 Subject: [PATCH 17/26] Use better typing --- src/airline_agent/tools/booking.py | 75 ++++++++++++++++++------------ src/airline_agent/types/booking.py | 43 +++++++++++++---- 2 files changed, 79 insertions(+), 39 deletions(-) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index b0bca38..3beb25d 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -10,7 +10,9 @@ FareType, Flight, FlightBooking, - ServiceAddOn, + GenericServiceAddOn, + SeatServiceAddOn, + SeatType, ServiceType, ) @@ -282,32 +284,42 @@ def add_service_to_booking( # Add the service add-on now = DEMO_DATETIME - # For seat selection services, validate that preferences/assignments are only set for seat selection - seat_selection_types = ["standard_seat_selection", "premium_seat_selection", "upfront_plus_seating"] - if service_type not in seat_selection_types and (seat_preference or seat_assignment): - msg = "seat_preference and seat_assignment can only be set for seat selection service types" - raise ValueError(msg) + # Create appropriate add-on type based on service type + addon: SeatServiceAddOn | GenericServiceAddOn + match service_type: + case "standard_seat_selection" | "premium_seat_selection" | "upfront_plus_seating": + seat_type: SeatType + match service_type: + case "standard_seat_selection": + seat_type = "standard" + case "premium_seat_selection": + seat_type = "stretch" + case "upfront_plus_seating": + seat_type = "upfront_plus" + + addon = SeatServiceAddOn( + service_type=service_type, + price=addon_option.price, + currency=addon_option.currency, + added_at=now, + seat_preference=seat_preference, + seat_assignment=seat_assignment, + seat_type=seat_type, + ) + case _: + # For non-seat services, validate that seat parameters weren't provided + if seat_preference or seat_assignment: + msg = "seat_preference and seat_assignment can only be set for seat selection service types" + raise ValueError(msg) + + addon = GenericServiceAddOn( + service_type=service_type, + price=addon_option.price, + currency=addon_option.currency, + added_at=now, + ) - # Determine seat type based on service type - seat_type = None - if service_type == "standard_seat_selection": - seat_type = "standard" - elif service_type == "premium_seat_selection": - seat_type = "stretch" - elif service_type == "upfront_plus_seating": - seat_type = "upfront_plus" - - flight_booking.add_ons.append( - ServiceAddOn( - service_type=service_type, - price=addon_option.price, - currency=addon_option.currency, - added_at=now, - seat_preference=seat_preference, - seat_assignment=seat_assignment, - seat_type=seat_type, - ) - ) + flight_booking.add_ons.append(addon) # Update booking timestamp booking.status.updated_at = now @@ -320,11 +332,12 @@ def add_service_to_booking( def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str: """Assign a seat to a flight booking based on preferences, fare type, or randomly.""" # Check if any seat selection add-on exists with an assignment - seat_selection_types = ["standard_seat_selection", "premium_seat_selection", "upfront_plus_seating"] - seat_addon = next( - (addon for addon in flight_booking.add_ons if addon.service_type in seat_selection_types), - None, - ) + seat_addon: SeatServiceAddOn | None = None + for addon in flight_booking.add_ons: + if isinstance(addon, SeatServiceAddOn): + seat_addon = addon + break + if seat_addon and seat_addon.seat_assignment: return seat_addon.seat_assignment diff --git a/src/airline_agent/types/booking.py b/src/airline_agent/types/booking.py index 7290fbd..f6ae114 100644 --- a/src/airline_agent/types/booking.py +++ b/src/airline_agent/types/booking.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Literal +from typing import Annotated, Literal -from pydantic import BaseModel, Field, computed_field +from pydantic import BaseModel, Discriminator, Field, computed_field FareType = Literal["basic", "economy", "premium", "business"] ServiceType = Literal[ @@ -15,6 +15,16 @@ "refundability", "change_cancel_fee_waived", ] +SeatServiceType = Literal["standard_seat_selection", "premium_seat_selection", "upfront_plus_seating"] +GenericServiceType = Literal[ + "checked_bag", + "carry_on", + "priority_boarding", + "travel_insurance", + "refundability", + "change_cancel_fee_waived", +] +SeatType = Literal["standard", "stretch", "upfront_plus"] FlightStatus = Literal[ "scheduled", "on_time", @@ -26,13 +36,18 @@ ] -class ServiceAddOn(BaseModel): - """A service add-on purchased for a flight.""" +class ServiceAddOnBase(BaseModel): + """Base class for service add-ons with common fields.""" - service_type: ServiceType price: float currency: str = "USD" added_at: datetime = Field(..., description="Timestamp when the add-on was added") + + +class SeatServiceAddOn(ServiceAddOnBase): + """Seat selection service add-on with seat-specific fields.""" + + service_type: SeatServiceType seat_preference: str | None = Field( default=None, description='Seat preference such as "window", "aisle", or "middle"', @@ -41,12 +56,24 @@ class ServiceAddOn(BaseModel): default=None, description='Assigned seat identifier, for example "12A" or "15F"', ) - seat_type: str | None = Field( - default=None, - description=('Type of seat selected, for example "standard", "stretch", or ' '"upfront_plus"'), + seat_type: SeatType = Field( + ..., + description='Type of seat selected: "standard", "stretch", or "upfront_plus"', ) +class GenericServiceAddOn(ServiceAddOnBase): + """Non-seat service add-on (bags, insurance, priority boarding, etc.).""" + + service_type: GenericServiceType + + +ServiceAddOn = Annotated[ + SeatServiceAddOn | GenericServiceAddOn, + Discriminator("service_type"), +] + + class Fare(BaseModel): """Frontier Airlines fare bundle. No separate cabin classes - all passengers in same cabin.""" From f35a5a8bb83cc1fcf693c759dac821b2aeaa8e6f Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 15:28:42 -0800 Subject: [PATCH 18/26] Switch to generating flights at agent boot time --- README.md | 6 +- pyproject.toml | 1 - .../backend/services/airline_chat.py | 4 +- src/airline_agent/constants.py | 1 + src/airline_agent/data_generation/__init__.py | 0 .../data_generation}/generate_flights.py | 395 +++++++++--------- src/airline_agent/tools/booking.py | 11 +- 7 files changed, 197 insertions(+), 221 deletions(-) create mode 100644 src/airline_agent/data_generation/__init__.py rename {scripts => src/airline_agent/data_generation}/generate_flights.py (53%) mode change 100755 => 100644 diff --git a/README.md b/README.md index a9b6ef6..890da5c 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,7 @@ This repository contains an example of a production-grade AI Agent that is integ ```console $ hatch run create-vector-database ``` -7. Generate flight schedule data: - ```console - $ hatch run generate-flights - ``` -8. Install frontend dependencies: +7. Install frontend dependencies: ```console $ cd frontend $ npm install diff --git a/pyproject.toml b/pyproject.toml index e661ead..1a3c022 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ create-demo-project = "python -m airline_agent.cleanlab_utils.create_new_demo_pr open-project = "python -m airline_agent.cleanlab_utils.open_project {args}" fetch-pages = "python -m airline_agent.data_preparation.fetch_pages {args}" create-vector-database = "python -m airline_agent.preprocessing.create_vector_database {args}" -generate-flights = "python scripts/generate_flights.py {args}" backend-server = "python -m airline_agent.backend.app {args}" diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 65f773e..166f741 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -106,9 +106,7 @@ def get_cleanlab_project() -> Project: kb_path=str(pathlib.Path(__file__).parents[4] / "data/kb.json"), vector_index_path=str(pathlib.Path(__file__).parents[4] / "data/vector-db"), ) -booking = BookingTools( - flights_path=str(pathlib.Path(__file__).parents[4] / "data/flights.json"), -) +booking = BookingTools() project = get_cleanlab_project() agent = create_agent(kb, booking) diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index ead903b..3577861 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -7,6 +7,7 @@ DEMO_DATE = date(2025, 11, 5) DEMO_DATETIME = datetime(2025, 11, 5, 14, 0, 0, tzinfo=ZoneInfo("America/Los_Angeles")) FLIGHT_DATA_DATE = DEMO_DATE + timedelta(days=7) +FLIGHT_DATA_NUM_DAYS = 8 OFFICIAL_DEMO_PROJECT_ID = "3aae1f96-2dda-492f-8c86-17d453d3c298" # to copy configuration from STAGING_DEMO_PROJECT_ID = "6de236e4-c6e7-456c-b248-872236010992" diff --git a/src/airline_agent/data_generation/__init__.py b/src/airline_agent/data_generation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/generate_flights.py b/src/airline_agent/data_generation/generate_flights.py old mode 100755 new mode 100644 similarity index 53% rename from scripts/generate_flights.py rename to src/airline_agent/data_generation/generate_flights.py index e39e561..f2065a0 --- a/scripts/generate_flights.py +++ b/src/airline_agent/data_generation/generate_flights.py @@ -1,17 +1,17 @@ -#!/usr/bin/env python3 """ -Script to generate Frontier Airlines (F9) flight data for SF Bay Area to New York routes. +Module to generate Frontier Airlines (F9) flight data for SF Bay Area to New York routes. Includes direct flights and connecting flights with layovers through hub airports. """ -import json import random from datetime import UTC, datetime, timedelta -from pathlib import Path +from zoneinfo import ZoneInfo -from airline_agent.constants import FLIGHT_DATA_DATE +from airline_agent.constants import FLIGHT_DATA_DATE, FLIGHT_DATA_NUM_DAYS +from airline_agent.types.booking import Fare, Flight, ServiceAddOnOption # Constants +RNG_SEED = 42 SHORT_FLIGHT_THRESHOLD_HOURS = 2.0 # Threshold for short flights (hours) DURATION_ADJUSTMENT = 0.1 # Adjustment for SJC/OAK flights (hours) SF_BAY_AIRPORTS = {"SJC", "OAK"} # SF Bay Area airports (excluding SFO) @@ -147,22 +147,27 @@ if sf in SF_AIRPORTS and hub in HUB_AIRPORTS: FLIGHT_DURATIONS[(hub, sf)] = duration -# Timezone offsets (PST/PDT for SF airports, EST/EDT for NYC) -SF_TZ_OFFSET = -8 # PST -NYC_TZ_OFFSET = -5 # EST - -# Hub airport timezones -HUB_TZ_OFFSETS = { - "DEN": -7, # MST - "ORD": -6, # CST - "ATL": -5, # EST - "DFW": -6, # CST - "LAS": -8, # PST - "PHX": -7, # MST - "SEA": -8, # PST - "IAH": -6, # CST - "MSP": -6, # CST - "DTW": -5, # EST +# Timezone mappings for airports +AIRPORT_TIMEZONES = { + # SF Bay Area + "SFO": ZoneInfo("America/Los_Angeles"), + "SJC": ZoneInfo("America/Los_Angeles"), + "OAK": ZoneInfo("America/Los_Angeles"), + # NYC + "JFK": ZoneInfo("America/New_York"), + "EWR": ZoneInfo("America/New_York"), + "LGA": ZoneInfo("America/New_York"), + # Hubs + "DEN": ZoneInfo("America/Denver"), + "ORD": ZoneInfo("America/Chicago"), + "ATL": ZoneInfo("America/New_York"), + "DFW": ZoneInfo("America/Chicago"), + "LAS": ZoneInfo("America/Los_Angeles"), + "PHX": ZoneInfo("America/Phoenix"), + "SEA": ZoneInfo("America/Los_Angeles"), + "IAH": ZoneInfo("America/Chicago"), + "MSP": ZoneInfo("America/Chicago"), + "DTW": ZoneInfo("America/Detroit"), } @@ -172,54 +177,53 @@ def get_flight_duration(origin: str, destination: str) -> float: return FLIGHT_DURATIONS.get(route, 3.0) # Default 3 hours if not found -def get_timezone_offset(airport: str) -> int: - """Get timezone offset for an airport.""" - if airport in SF_AIRPORTS: - return SF_TZ_OFFSET - if airport in NYC_AIRPORTS: - return NYC_TZ_OFFSET - return HUB_TZ_OFFSETS.get(airport, -5) +def get_airport_timezone(airport: str) -> ZoneInfo: + """Get timezone for an airport.""" + if airport not in AIRPORT_TIMEZONES: + msg = f"Unknown airport timezone: {airport}" + raise ValueError(msg) + return AIRPORT_TIMEZONES[airport] -def generate_fares(rng: random.Random) -> list[dict]: +def generate_fares(rng: random.Random) -> list[Fare]: """Generate random fares for a flight with different fare bundles (Frontier Airlines model).""" fares = [] # Basic fare: no services included basic_price = rng.uniform(*FARE_BASE_PRICES["basic"]["price_range"]) fares.append( - { - "fare_type": "basic", - "price_total": round(basic_price, 2), - "currency": "USD", - "seats_available": rng.randint(5, 15), - "included_services": [], - "checked_bags_included": 0, - } + Fare( + fare_type="basic", + price_total=round(basic_price, 2), + currency="USD", + seats_available=rng.randint(5, 15), + included_services=[], + checked_bags_included=0, + ) ) # Economy bundle: Basic + Carry on, Standard seat selection, Refundability, Change/cancel fee waived economy_price = rng.uniform(*FARE_BASE_PRICES["economy"]["price_range"]) fares.append( - { - "fare_type": "economy", - "price_total": round(economy_price, 2), - "currency": "USD", - "seats_available": rng.randint(3, 12), - "included_services": ["carry_on", "standard_seat_selection", "refundability", "change_cancel_fee_waived"], - "checked_bags_included": 0, - } + Fare( + fare_type="economy", + price_total=round(economy_price, 2), + currency="USD", + seats_available=rng.randint(3, 12), + included_services=["carry_on", "standard_seat_selection", "refundability", "change_cancel_fee_waived"], + checked_bags_included=0, + ) ) # Premium bundle: Economy + Premium seat selection + Priority Boarding premium_price = rng.uniform(*FARE_BASE_PRICES["premium"]["price_range"]) fares.append( - { - "fare_type": "premium", - "price_total": round(premium_price, 2), - "currency": "USD", - "seats_available": rng.randint(2, 8), - "included_services": [ + Fare( + fare_type="premium", + price_total=round(premium_price, 2), + currency="USD", + seats_available=rng.randint(2, 8), + included_services=[ "carry_on", "standard_seat_selection", "refundability", @@ -227,19 +231,19 @@ def generate_fares(rng: random.Random) -> list[dict]: "premium_seat_selection", "priority_boarding", ], - "checked_bags_included": 0, - } + checked_bags_included=0, + ) ) # Business bundle: Premium + 2 checked bags + UpFront Plus Seating business_price = rng.uniform(*FARE_BASE_PRICES["business"]["price_range"]) fares.append( - { - "fare_type": "business", - "price_total": round(business_price, 2), - "currency": "USD", - "seats_available": rng.randint(1, 4), - "included_services": [ + Fare( + fare_type="business", + price_total=round(business_price, 2), + currency="USD", + seats_available=rng.randint(1, 4), + included_services=[ "carry_on", "standard_seat_selection", "refundability", @@ -248,70 +252,70 @@ def generate_fares(rng: random.Random) -> list[dict]: "priority_boarding", "upfront_plus_seating", ], - "checked_bags_included": 2, - } + checked_bags_included=2, + ) ) return fares -def generate_add_ons(rng: random.Random) -> list[dict]: +def generate_add_ons(rng: random.Random) -> list[ServiceAddOnOption]: """Generate available add-on services for a flight.""" return [ - { - "service_type": "checked_bag", - "price": round(rng.uniform(30, 40), 2), - "currency": "USD", - "description": "One checked bag (up to 50 lbs, 62 linear inches)", - }, - { - "service_type": "carry_on", - "price": round(rng.uniform(20, 30), 2), - "currency": "USD", - "description": "One carry-on bag (personal item included)", - }, - { - "service_type": "standard_seat_selection", - "price": round(rng.uniform(10, 25), 2), - "currency": "USD", - "description": "Select a standard seat in advance", - }, - { - "service_type": "premium_seat_selection", - "price": round(rng.uniform(25, 45), 2), - "currency": "USD", - "description": "Select a stretch seat with extra legroom", - }, - { - "service_type": "upfront_plus_seating", - "price": round(rng.uniform(50, 100), 2), - "currency": "USD", - "description": "UpFront Plus seating in first two rows with guaranteed empty middle seat", - }, - { - "service_type": "priority_boarding", - "price": round(rng.uniform(8, 15), 2), - "currency": "USD", - "description": "Priority boarding with overhead bin space", - }, - { - "service_type": "travel_insurance", - "price": round(rng.uniform(15, 30), 2), - "currency": "USD", - "description": "Trip protection insurance", - }, - { - "service_type": "refundability", - "price": round(rng.uniform(30, 60), 2), - "currency": "USD", - "description": "Add refundability to your booking", - }, - { - "service_type": "change_cancel_fee_waived", - "price": round(rng.uniform(20, 40), 2), - "currency": "USD", - "description": "Waive change and cancel fees", - }, + ServiceAddOnOption( + service_type="checked_bag", + price=round(rng.uniform(30, 40), 2), + currency="USD", + description="One checked bag (up to 50 lbs, 62 linear inches)", + ), + ServiceAddOnOption( + service_type="carry_on", + price=round(rng.uniform(20, 30), 2), + currency="USD", + description="One carry-on bag (personal item included)", + ), + ServiceAddOnOption( + service_type="standard_seat_selection", + price=round(rng.uniform(10, 25), 2), + currency="USD", + description="Select a standard seat in advance", + ), + ServiceAddOnOption( + service_type="premium_seat_selection", + price=round(rng.uniform(25, 45), 2), + currency="USD", + description="Select a stretch seat with extra legroom", + ), + ServiceAddOnOption( + service_type="upfront_plus_seating", + price=round(rng.uniform(50, 100), 2), + currency="USD", + description="UpFront Plus seating in first two rows with guaranteed empty middle seat", + ), + ServiceAddOnOption( + service_type="priority_boarding", + price=round(rng.uniform(8, 15), 2), + currency="USD", + description="Priority boarding with overhead bin space", + ), + ServiceAddOnOption( + service_type="travel_insurance", + price=round(rng.uniform(15, 30), 2), + currency="USD", + description="Trip protection insurance", + ), + ServiceAddOnOption( + service_type="refundability", + price=round(rng.uniform(30, 60), 2), + currency="USD", + description="Add refundability to your booking", + ), + ServiceAddOnOption( + service_type="change_cancel_fee_waived", + price=round(rng.uniform(20, 40), 2), + currency="USD", + description="Waive change and cancel fees", + ), ] @@ -327,7 +331,7 @@ def generate_direct_flights( num_days: int = 8, origin_airports: list[str] | None = None, dest_airports: list[str] | None = None, -) -> list[dict]: +) -> list[Flight]: """Generate direct flights from origin airports to destination airports.""" if origin_airports is None: origin_airports = SF_AIRPORTS @@ -352,30 +356,29 @@ def generate_direct_flights( carrier_code = CARRIER_CODE - departure_time = date.replace(hour=hour, minute=minute, second=0, microsecond=0) + # Create timezone-aware departure time + origin_tz = get_airport_timezone(origin) + departure_time = date.replace(hour=hour, minute=minute, second=0, microsecond=0, tzinfo=origin_tz) # Calculate arrival time duration = get_flight_duration(origin, destination) - arrival_time = departure_time + timedelta(hours=duration) - - # Adjust for timezone - departure_offset = get_timezone_offset(origin) - arrival_offset = get_timezone_offset(destination) - - departure_str = departure_time.strftime(f"%Y-%m-%dT%H:%M:00{departure_offset:+03d}:00") - arrival_str = arrival_time.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset:+03d}:00") - - flight = { - "id": generate_flight_id(origin, destination, departure_time, carrier_code), - "origin": origin, - "destination": destination, - "departure": departure_str, - "arrival": arrival_str, - "flight_number": f"{carrier_code} {rng.randint(100, 999)}", - "carrier": carrier_code, - "fares": generate_fares(rng), - "add_ons": generate_add_ons(rng), - } + arrival_time_naive = departure_time + timedelta(hours=duration) + + # Convert to destination timezone + dest_tz = get_airport_timezone(destination) + arrival_time = arrival_time_naive.astimezone(dest_tz) + + flight = Flight( + id=generate_flight_id(origin, destination, departure_time, carrier_code), + origin=origin, + destination=destination, + departure=departure_time, + arrival=arrival_time, + flight_number=f"{carrier_code} {rng.randint(100, 999)}", + carrier=carrier_code, + fares=generate_fares(rng), + add_ons=generate_add_ons(rng), + ) flights.append(flight) @@ -388,7 +391,7 @@ def generate_connecting_flights( num_days: int = 8, origin_airports: list[str] | None = None, dest_airports: list[str] | None = None, -) -> list[dict]: +) -> list[Flight]: """Generate connecting flights from origin airports to destination airports via hub airports.""" if origin_airports is None: origin_airports = SF_AIRPORTS @@ -415,10 +418,17 @@ def generate_connecting_flights( # First leg: Origin -> Hub hour1 = rng.randint(6, 18) minute1 = rng.choice([0, 15, 30, 45]) - departure_time_leg1 = date.replace(hour=hour1, minute=minute1, second=0, microsecond=0) + + origin_tz = get_airport_timezone(origin) + departure_time_leg1 = date.replace( + hour=hour1, minute=minute1, second=0, microsecond=0, tzinfo=origin_tz + ) duration1 = get_flight_duration(origin, hub) - arrival_time_leg1 = departure_time_leg1 + timedelta(hours=duration1) + arrival_time_leg1_naive = departure_time_leg1 + timedelta(hours=duration1) + + hub_tz = get_airport_timezone(hub) + arrival_time_leg1 = arrival_time_leg1_naive.astimezone(hub_tz) # Layover: 45 minutes to 3 hours layover_hours = rng.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) @@ -426,95 +436,68 @@ def generate_connecting_flights( # Second leg: Hub -> Destination duration2 = get_flight_duration(hub, destination) - arrival_time_leg2 = departure_time_leg2 + timedelta(hours=duration2) + arrival_time_leg2_naive = departure_time_leg2 + timedelta(hours=duration2) + + dest_tz = get_airport_timezone(destination) + arrival_time_leg2 = arrival_time_leg2_naive.astimezone(dest_tz) # First leg - departure_offset_leg1 = get_timezone_offset(origin) - arrival_offset_leg1 = get_timezone_offset(hub) - - flight1 = { - "id": generate_flight_id(origin, hub, departure_time_leg1, carrier_code), - "origin": origin, - "destination": hub, - "departure": departure_time_leg1.strftime( - f"%Y-%m-%dT%H:%M:00{departure_offset_leg1:+03d}:00" - ), - "arrival": arrival_time_leg1.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg1:+03d}:00"), - "flight_number": f"{carrier_code} {rng.randint(100, 999)}", - "carrier": carrier_code, - "fares": generate_fares(rng), - "add_ons": generate_add_ons(rng), - } + flight1 = Flight( + id=generate_flight_id(origin, hub, departure_time_leg1, carrier_code), + origin=origin, + destination=hub, + departure=departure_time_leg1, + arrival=arrival_time_leg1, + flight_number=f"{carrier_code} {rng.randint(100, 999)}", + carrier=carrier_code, + fares=generate_fares(rng), + add_ons=generate_add_ons(rng), + ) # Second leg - departure_offset_leg2 = get_timezone_offset(hub) - arrival_offset_leg2 = get_timezone_offset(destination) - - flight2 = { - "id": generate_flight_id(hub, destination, departure_time_leg2, carrier_code), - "origin": hub, - "destination": destination, - "departure": departure_time_leg2.strftime( - f"%Y-%m-%dT%H:%M:00{departure_offset_leg2:+03d}:00" - ), - "arrival": arrival_time_leg2.strftime(f"%Y-%m-%dT%H:%M:00{arrival_offset_leg2:+03d}:00"), - "flight_number": f"{carrier_code} {rng.randint(100, 999)}", - "carrier": carrier_code, - "fares": generate_fares(rng), - "add_ons": generate_add_ons(rng), - } + flight2 = Flight( + id=generate_flight_id(hub, destination, departure_time_leg2, carrier_code), + origin=hub, + destination=destination, + departure=departure_time_leg2, + arrival=arrival_time_leg2, + flight_number=f"{carrier_code} {rng.randint(100, 999)}", + carrier=carrier_code, + fares=generate_fares(rng), + add_ons=generate_add_ons(rng), + ) flights.extend([flight1, flight2]) return flights -def main(): - """Main function to generate and save flight data.""" +def generate_flight_data() -> list[Flight]: + """ + Generate comprehensive flight data for SF Bay Area to New York routes. + + Returns: + List of Flight objects + """ # Create seeded random number generator for reproducibility - rng = random.Random(42) # noqa: S311 + rng = random.Random(RNG_SEED) # noqa: S311 start_date = datetime.combine(FLIGHT_DATA_DATE, datetime.min.time(), tzinfo=UTC) - num_days = 8 - - print(f"Generating comprehensive flight data starting from {FLIGHT_DATA_DATE.isoformat()}...") # Generate SF -> NYC flights (direct only) - print("Generating direct flights (SF -> NYC)...") direct_flights_sf_to_nyc = generate_direct_flights( - rng, start_date, num_days=num_days, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS + rng, start_date, num_days=FLIGHT_DATA_NUM_DAYS, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS ) - print(f"Generated {len(direct_flights_sf_to_nyc)} direct flights from SF to NYC") # Generate NYC -> SF flights (direct only) - print("Generating direct flights (NYC -> SF)...") direct_flights_nyc_to_sf = generate_direct_flights( - rng, start_date, num_days=num_days, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS + rng, start_date, num_days=FLIGHT_DATA_NUM_DAYS, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS ) - print(f"Generated {len(direct_flights_nyc_to_sf)} direct flights from NYC to SF") # Combine all flights (direct only, no transfers) all_flights = direct_flights_sf_to_nyc + direct_flights_nyc_to_sf - # Get the project root (two levels up from scripts/) - project_root = Path(__file__).parent.parent - flights_file = project_root / "data" / "flights.json" - # Sort by departure time - all_flights.sort(key=lambda x: x["departure"]) - - output_data = {"flights": all_flights} - - with open(flights_file, "w") as f: - json.dump(output_data, f, indent=2) - - end_date = start_date + timedelta(days=num_days - 1) - print(f"\n✓ Successfully saved {len(all_flights)} total flights to {flights_file}") - print(f" - Direct flights SF->NYC: {len(direct_flights_sf_to_nyc)}") - print(f" - Direct flights NYC->SF: {len(direct_flights_nyc_to_sf)}") - print(f" - All flights are DIRECT flights from {start_date.date().isoformat()} to {end_date.date().isoformat()}") - print(" - No connecting/transfer flights included") - + all_flights.sort(key=lambda x: x.departure) -if __name__ == "__main__": - main() + return all_flights diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 3beb25d..cdb4a4a 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -1,9 +1,9 @@ -import json import random from datetime import date, datetime, timedelta from typing import Any from airline_agent.constants import DEMO_DATETIME +from airline_agent.data_generation.generate_flights import generate_flight_data from airline_agent.types.booking import ( Booking, BookingStatus, @@ -26,11 +26,10 @@ class BookingTools: - def __init__(self, flights_path: str): - self._flights_path = flights_path - with open(flights_path) as f: - raw = json.load(f) - self._flights: dict[str, Flight] = {x["id"]: Flight.model_validate(x) for x in raw["flights"]} + def __init__(self) -> None: + # Generate flight data dynamically at initialization + flights = generate_flight_data() + self._flights: dict[str, Flight] = {flight.id: flight for flight in flights} # Initialize reservations storage (in-memory only) self._reservations: dict[str, Booking] = {} From 334e1f24faf0bee208d6309755bfc5dd1ea19113 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 15:35:13 -0800 Subject: [PATCH 19/26] Clean up instructions --- .../backend/services/airline_chat.py | 17 ++--------------- src/airline_agent/constants.py | 6 +++++- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 166f741..94365a8 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -4,7 +4,6 @@ import pathlib import uuid from collections.abc import AsyncGenerator -from textwrap import dedent from cleanlab_codex import Client, Project from codex.types import ProjectValidateResponse @@ -47,7 +46,7 @@ get_tools_in_openai_format, run_cleanlab_validation_logging_tools, ) -from airline_agent.constants import AGENT_BASE_INSTRUCTIONS, AGENT_MODEL, DEMO_DATETIME +from airline_agent.constants import AGENT_INSTRUCTIONS, AGENT_MODEL from airline_agent.tools.booking import BookingTools from airline_agent.tools.knowledge_base import KnowledgeBase @@ -57,25 +56,13 @@ logger.setLevel(logging.INFO) -def instructions() -> str: - current_date_str = DEMO_DATETIME.date().isoformat() - current_datetime_str = DEMO_DATETIME.strftime("%Y-%m-%d %H:%M:%S %Z") - return dedent(f""" - {AGENT_BASE_INSTRUCTIONS} - - ## Context: - - Today's date: {current_date_str} - - Current time: {current_datetime_str} - """).strip() - - def create_agent(kb: KnowledgeBase, booking: BookingTools) -> Agent: """Create the airline support agent.""" model = OpenAIChatModel(model_name=AGENT_MODEL, settings=ModelSettings(temperature=0.0)) return Agent( model=model, - instructions=instructions, + instructions=AGENT_INSTRUCTIONS, tools=[ kb.get_article, kb.search, diff --git a/src/airline_agent/constants.py b/src/airline_agent/constants.py index 3577861..5eeee5a 100644 --- a/src/airline_agent/constants.py +++ b/src/airline_agent/constants.py @@ -17,7 +17,7 @@ RAG_CHUNK_OVERLAP = 200 CONTEXT_RETRIEVAL_TOOLS = ["search", "get_article", "list_directory"] AGENT_MODEL = "gpt-4o" -AGENT_BASE_INSTRUCTIONS = """ +AGENT_INSTRUCTIONS = f""" You are an AI customer support agent for Frontier Airlines. You can use tools to access to a knowledge base of articles and documents about the airline's services, policies, and procedures. ## You have access to the following tools: @@ -48,6 +48,10 @@ - When a booking is successfully created, provide the booking ID and confirmation details clearly. - If you book flights, provide the booking ID and summarize the flights booked and total price. - If the user asks about anything unrelated to the airline, politely inform them that you can only assist with airline-related inquiries. + +## Context: +- Today's date: {DEMO_DATETIME.date().isoformat()} +- Current time: {DEMO_DATETIME.strftime("%H:%M:%S %Z")} """.strip() FALLBACK_RESPONSE = "I'm sorry, but I don't have the information you're looking for. Please rephrase the question or contact Frontier Airlines customer support for further assistance." From 4748e402917903bda1baf3e6ed32302f7d0ce9dc Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 15:41:06 -0800 Subject: [PATCH 20/26] Use toolsets --- .../backend/services/airline_chat.py | 15 +-------------- src/airline_agent/tools/booking.py | 19 +++++++++++++++++++ src/airline_agent/tools/knowledge_base.py | 6 ++++++ 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/airline_agent/backend/services/airline_chat.py b/src/airline_agent/backend/services/airline_chat.py index 94365a8..0563069 100644 --- a/src/airline_agent/backend/services/airline_chat.py +++ b/src/airline_agent/backend/services/airline_chat.py @@ -63,20 +63,7 @@ def create_agent(kb: KnowledgeBase, booking: BookingTools) -> Agent: return Agent( model=model, instructions=AGENT_INSTRUCTIONS, - tools=[ - kb.get_article, - kb.search, - kb.list_directory, - booking.search_flights, - booking.get_fare_details, - booking.book_flights, - booking.get_booking, - booking.get_my_bookings, - booking.add_service_to_booking, - booking.check_in, - booking.get_flight_timings, - booking.get_flight_status, - ], + toolsets=[kb.tools, booking.tools], ) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index cdb4a4a..0d622c9 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -2,6 +2,8 @@ from datetime import date, datetime, timedelta from typing import Any +from pydantic_ai.toolsets import FunctionToolset + from airline_agent.constants import DEMO_DATETIME from airline_agent.data_generation.generate_flights import generate_flight_data from airline_agent.types.booking import ( @@ -566,3 +568,20 @@ def get_flight_status(self, flight_id: str) -> dict[str, Any]: "scheduled_arrival": flight.arrival.isoformat(), "carrier": flight.carrier, } + + @property + def tools(self) -> FunctionToolset: + """Returns a FunctionToolset containing all booking tools.""" + return FunctionToolset( + tools=[ + self.search_flights, + self.get_fare_details, + self.book_flights, + self.get_booking, + self.get_my_bookings, + self.add_service_to_booking, + self.check_in, + self.get_flight_timings, + self.get_flight_status, + ] + ) diff --git a/src/airline_agent/tools/knowledge_base.py b/src/airline_agent/tools/knowledge_base.py index e1988cb..4893a93 100644 --- a/src/airline_agent/tools/knowledge_base.py +++ b/src/airline_agent/tools/knowledge_base.py @@ -4,6 +4,7 @@ from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage from llama_index.embeddings.openai import OpenAIEmbedding # type: ignore[import-untyped] from pydantic_ai import ModelRetry +from pydantic_ai.toolsets import FunctionToolset from airline_agent.constants import RAG_EMBED_MODEL from airline_agent.types.knowledge_base import DirectoryEntry, KBArticle, SearchResult @@ -105,3 +106,8 @@ def list_directory(self, directory: str) -> list[DirectoryEntry]: dir_name = suffix.split("/", 1)[0] entries.add(DirectoryEntry(name=dir_name, kind="directory")) return sorted(entries, key=lambda e: e.name) + + @property + def tools(self) -> FunctionToolset: + """Returns a FunctionToolset containing all knowledge base tools.""" + return FunctionToolset(tools=[self.get_article, self.search, self.list_directory]) From 1284a94c1018f5a8350660172273b2d6693fb63f Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 15:53:14 -0800 Subject: [PATCH 21/26] Remove dead code related to connecting flights --- .../data_generation/generate_flights.py | 181 +----------------- 1 file changed, 5 insertions(+), 176 deletions(-) diff --git a/src/airline_agent/data_generation/generate_flights.py b/src/airline_agent/data_generation/generate_flights.py index f2065a0..fa80721 100644 --- a/src/airline_agent/data_generation/generate_flights.py +++ b/src/airline_agent/data_generation/generate_flights.py @@ -1,6 +1,6 @@ """ Module to generate Frontier Airlines (F9) flight data for SF Bay Area to New York routes. -Includes direct flights and connecting flights with layovers through hub airports. +Includes direct flights only. """ import random @@ -19,9 +19,6 @@ # San Francisco Bay Area airports SF_AIRPORTS = ["SFO", "SJC", "OAK"] -# Common hub airports for layovers between SF and NYC -HUB_AIRPORTS = ["DEN", "ORD", "ATL", "DFW", "LAS", "PHX", "SEA", "IAH", "MSP", "DTW"] - # New York airports NYC_AIRPORTS = ["JFK", "EWR", "LGA"] @@ -43,16 +40,6 @@ ("SFO", "JFK"): 5.5, ("SFO", "EWR"): 5.3, ("SFO", "LGA"): 5.4, - ("SFO", "DEN"): 2.5, - ("SFO", "ORD"): 4.0, - ("SFO", "ATL"): 4.5, - ("SFO", "DFW"): 3.5, - ("SFO", "LAS"): 1.5, - ("SFO", "PHX"): 1.8, - ("SFO", "SEA"): 2.0, - ("SFO", "IAH"): 3.8, - ("SFO", "MSP"): 3.5, - ("SFO", "DTW"): 4.2, } # SJC and OAK are similar to SFO, with slight variations @@ -82,71 +69,11 @@ else: FLIGHT_DURATIONS[(orig, "OAK")] = duration -# Hub to NYC routes -FLIGHT_DURATIONS.update( - { - ("DEN", "JFK"): 3.5, - ("DEN", "EWR"): 3.3, - ("DEN", "LGA"): 3.4, - ("ORD", "JFK"): 2.0, - ("ORD", "EWR"): 2.0, - ("ORD", "LGA"): 2.0, - ("ATL", "JFK"): 2.5, - ("ATL", "EWR"): 2.3, - ("ATL", "LGA"): 2.4, - ("DFW", "JFK"): 3.5, - ("DFW", "EWR"): 3.3, - ("DFW", "LGA"): 3.4, - ("LAS", "JFK"): 5.0, - ("LAS", "EWR"): 4.8, - ("LAS", "LGA"): 4.9, - ("PHX", "JFK"): 4.5, - ("PHX", "EWR"): 4.3, - ("PHX", "LGA"): 4.4, - ("SEA", "JFK"): 5.5, - ("SEA", "EWR"): 5.3, - ("SEA", "LGA"): 5.4, - ("IAH", "JFK"): 3.0, - ("IAH", "EWR"): 2.8, - ("IAH", "LGA"): 2.9, - ("MSP", "JFK"): 2.8, - ("MSP", "EWR"): 2.6, - ("MSP", "LGA"): 2.7, - ("DTW", "JFK"): 1.8, - ("DTW", "EWR"): 1.6, - ("DTW", "LGA"): 1.7, - } -) - -# NYC to Hub routes (reverse of hub to NYC) -for (hub, nyc), duration in list(FLIGHT_DURATIONS.items()): - if hub in HUB_AIRPORTS and nyc in NYC_AIRPORTS: - FLIGHT_DURATIONS[(nyc, hub)] = duration - # NYC to SF routes (reverse of SF to NYC) for (sf, nyc), duration in list(FLIGHT_DURATIONS.items()): if sf in SF_AIRPORTS and nyc in NYC_AIRPORTS: FLIGHT_DURATIONS[(nyc, sf)] = duration -# SF to Hub routes -for sf in SF_AIRPORTS: - for hub in HUB_AIRPORTS: - if (sf, hub) not in FLIGHT_DURATIONS: - # Use SFO duration as base - base_duration = FLIGHT_DURATIONS.get(("SFO", hub), 3.0) - if sf in SF_BAY_AIRPORTS: - if base_duration > SHORT_FLIGHT_THRESHOLD_HOURS: - FLIGHT_DURATIONS[(sf, hub)] = base_duration - DURATION_ADJUSTMENT - else: - FLIGHT_DURATIONS[(sf, hub)] = base_duration - else: - FLIGHT_DURATIONS[(sf, hub)] = base_duration - -# Hub to SF routes (reverse) -for (sf, hub), duration in list(FLIGHT_DURATIONS.items()): - if sf in SF_AIRPORTS and hub in HUB_AIRPORTS: - FLIGHT_DURATIONS[(hub, sf)] = duration - # Timezone mappings for airports AIRPORT_TIMEZONES = { # SF Bay Area @@ -157,17 +84,6 @@ "JFK": ZoneInfo("America/New_York"), "EWR": ZoneInfo("America/New_York"), "LGA": ZoneInfo("America/New_York"), - # Hubs - "DEN": ZoneInfo("America/Denver"), - "ORD": ZoneInfo("America/Chicago"), - "ATL": ZoneInfo("America/New_York"), - "DFW": ZoneInfo("America/Chicago"), - "LAS": ZoneInfo("America/Los_Angeles"), - "PHX": ZoneInfo("America/Phoenix"), - "SEA": ZoneInfo("America/Los_Angeles"), - "IAH": ZoneInfo("America/Chicago"), - "MSP": ZoneInfo("America/Chicago"), - "DTW": ZoneInfo("America/Detroit"), } @@ -385,96 +301,9 @@ def generate_direct_flights( return flights -def generate_connecting_flights( - rng: random.Random, - start_date: datetime, - num_days: int = 8, - origin_airports: list[str] | None = None, - dest_airports: list[str] | None = None, -) -> list[Flight]: - """Generate connecting flights from origin airports to destination airports via hub airports.""" - if origin_airports is None: - origin_airports = SF_AIRPORTS - if dest_airports is None: - dest_airports = NYC_AIRPORTS - - flights = [] - - for day in range(num_days): - date = start_date + timedelta(days=day) - - # Generate comprehensive connecting routes - all combinations - for origin in origin_airports: - for destination in dest_airports: - # Generate multiple connecting routes through different hubs - # Use all hubs to create many transfer options - for hub in HUB_AIRPORTS: - # Generate 1-3 connecting flights per hub per origin-destination pair per day - num_routes = rng.randint(1, 3) - - for _ in range(num_routes): - carrier_code = CARRIER_CODE - - # First leg: Origin -> Hub - hour1 = rng.randint(6, 18) - minute1 = rng.choice([0, 15, 30, 45]) - - origin_tz = get_airport_timezone(origin) - departure_time_leg1 = date.replace( - hour=hour1, minute=minute1, second=0, microsecond=0, tzinfo=origin_tz - ) - - duration1 = get_flight_duration(origin, hub) - arrival_time_leg1_naive = departure_time_leg1 + timedelta(hours=duration1) - - hub_tz = get_airport_timezone(hub) - arrival_time_leg1 = arrival_time_leg1_naive.astimezone(hub_tz) - - # Layover: 45 minutes to 3 hours - layover_hours = rng.choice([0.75, 1.0, 1.5, 2.0, 2.5, 3.0]) - departure_time_leg2 = arrival_time_leg1 + timedelta(hours=layover_hours) - - # Second leg: Hub -> Destination - duration2 = get_flight_duration(hub, destination) - arrival_time_leg2_naive = departure_time_leg2 + timedelta(hours=duration2) - - dest_tz = get_airport_timezone(destination) - arrival_time_leg2 = arrival_time_leg2_naive.astimezone(dest_tz) - - # First leg - flight1 = Flight( - id=generate_flight_id(origin, hub, departure_time_leg1, carrier_code), - origin=origin, - destination=hub, - departure=departure_time_leg1, - arrival=arrival_time_leg1, - flight_number=f"{carrier_code} {rng.randint(100, 999)}", - carrier=carrier_code, - fares=generate_fares(rng), - add_ons=generate_add_ons(rng), - ) - - # Second leg - flight2 = Flight( - id=generate_flight_id(hub, destination, departure_time_leg2, carrier_code), - origin=hub, - destination=destination, - departure=departure_time_leg2, - arrival=arrival_time_leg2, - flight_number=f"{carrier_code} {rng.randint(100, 999)}", - carrier=carrier_code, - fares=generate_fares(rng), - add_ons=generate_add_ons(rng), - ) - - flights.extend([flight1, flight2]) - - return flights - - def generate_flight_data() -> list[Flight]: """ - Generate comprehensive flight data for SF Bay Area to New York routes. + Generate direct flight data for SF Bay Area to New York routes. Returns: List of Flight objects @@ -484,17 +313,17 @@ def generate_flight_data() -> list[Flight]: start_date = datetime.combine(FLIGHT_DATA_DATE, datetime.min.time(), tzinfo=UTC) - # Generate SF -> NYC flights (direct only) + # Generate SF -> NYC flights direct_flights_sf_to_nyc = generate_direct_flights( rng, start_date, num_days=FLIGHT_DATA_NUM_DAYS, origin_airports=SF_AIRPORTS, dest_airports=NYC_AIRPORTS ) - # Generate NYC -> SF flights (direct only) + # Generate NYC -> SF flights direct_flights_nyc_to_sf = generate_direct_flights( rng, start_date, num_days=FLIGHT_DATA_NUM_DAYS, origin_airports=NYC_AIRPORTS, dest_airports=SF_AIRPORTS ) - # Combine all flights (direct only, no transfers) + # Combine all flights all_flights = direct_flights_sf_to_nyc + direct_flights_nyc_to_sf # Sort by departure time From 9a40bf92abb5a707889161f1c15fe22aa02b66f2 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 16:03:06 -0800 Subject: [PATCH 22/26] Clean up code and slightly improve domain modeling --- .../data_generation/generate_flights.py | 40 ++++--------------- src/airline_agent/tools/booking.py | 1 - 2 files changed, 7 insertions(+), 34 deletions(-) diff --git a/src/airline_agent/data_generation/generate_flights.py b/src/airline_agent/data_generation/generate_flights.py index fa80721..2453484 100644 --- a/src/airline_agent/data_generation/generate_flights.py +++ b/src/airline_agent/data_generation/generate_flights.py @@ -12,9 +12,6 @@ # Constants RNG_SEED = 42 -SHORT_FLIGHT_THRESHOLD_HOURS = 2.0 # Threshold for short flights (hours) -DURATION_ADJUSTMENT = 0.1 # Adjustment for SJC/OAK flights (hours) -SF_BAY_AIRPORTS = {"SJC", "OAK"} # SF Bay Area airports (excluding SFO) # San Francisco Bay Area airports SF_AIRPORTS = ["SFO", "SJC", "OAK"] @@ -24,7 +21,6 @@ # Frontier Airlines only CARRIER_CODE = "F9" -CARRIER_NAME = "Frontier" # Fare bundle base prices (Frontier Airlines style - no separate cabin classes) FARE_BASE_PRICES = { @@ -47,32 +43,16 @@ # SJC routes (slightly shorter than SFO) for (orig, dest), duration in BASE_DURATIONS.items(): if orig == "SFO": - if duration > SHORT_FLIGHT_THRESHOLD_HOURS: - FLIGHT_DURATIONS[("SJC", dest)] = duration - DURATION_ADJUSTMENT - else: - FLIGHT_DURATIONS[("SJC", dest)] = duration - if dest == "SFO": - if duration > SHORT_FLIGHT_THRESHOLD_HOURS: - FLIGHT_DURATIONS[(orig, "SJC")] = duration - DURATION_ADJUSTMENT - else: - FLIGHT_DURATIONS[(orig, "SJC")] = duration + FLIGHT_DURATIONS[("SJC", dest)] = duration - 0.2 # OAK routes (slightly shorter than SFO) for (orig, dest), duration in BASE_DURATIONS.items(): if orig == "SFO": - if duration > SHORT_FLIGHT_THRESHOLD_HOURS: - FLIGHT_DURATIONS[("OAK", dest)] = duration - DURATION_ADJUSTMENT - else: - FLIGHT_DURATIONS[("OAK", dest)] = duration - if dest == "SFO": - if duration > SHORT_FLIGHT_THRESHOLD_HOURS: - FLIGHT_DURATIONS[(orig, "OAK")] = duration - DURATION_ADJUSTMENT - else: - FLIGHT_DURATIONS[(orig, "OAK")] = duration + FLIGHT_DURATIONS[("OAK", dest)] = duration - 0.1 # NYC to SF routes (reverse of SF to NYC) for (sf, nyc), duration in list(FLIGHT_DURATIONS.items()): if sf in SF_AIRPORTS and nyc in NYC_AIRPORTS: - FLIGHT_DURATIONS[(nyc, sf)] = duration + FLIGHT_DURATIONS[(nyc, sf)] = duration + 1 # jet stream # Timezone mappings for airports AIRPORT_TIMEZONES = { @@ -89,15 +69,11 @@ def get_flight_duration(origin: str, destination: str) -> float: """Get flight duration in hours for a route.""" - route = (origin, destination) - return FLIGHT_DURATIONS.get(route, 3.0) # Default 3 hours if not found + return FLIGHT_DURATIONS[(origin, destination)] def get_airport_timezone(airport: str) -> ZoneInfo: """Get timezone for an airport.""" - if airport not in AIRPORT_TIMEZONES: - msg = f"Unknown airport timezone: {airport}" - raise ValueError(msg) return AIRPORT_TIMEZONES[airport] @@ -270,8 +246,6 @@ def generate_direct_flights( hour = rng.randint(6, 22) minute = rng.choice([0, 15, 30, 45]) - carrier_code = CARRIER_CODE - # Create timezone-aware departure time origin_tz = get_airport_timezone(origin) departure_time = date.replace(hour=hour, minute=minute, second=0, microsecond=0, tzinfo=origin_tz) @@ -285,13 +259,13 @@ def generate_direct_flights( arrival_time = arrival_time_naive.astimezone(dest_tz) flight = Flight( - id=generate_flight_id(origin, destination, departure_time, carrier_code), + id=generate_flight_id(origin, destination, departure_time, CARRIER_CODE), origin=origin, destination=destination, departure=departure_time, arrival=arrival_time, - flight_number=f"{carrier_code} {rng.randint(100, 999)}", - carrier=carrier_code, + flight_number=f"{CARRIER_CODE} {rng.randint(100, 999)}", + carrier=CARRIER_CODE, fares=generate_fares(rng), add_ons=generate_add_ons(rng), ) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 0d622c9..2b1d832 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -19,7 +19,6 @@ ) # Constants for flight status timing (in seconds) -DEPARTURE_PAST_THRESHOLD = -900 # 15 minutes past departure BOARDING_START_THRESHOLD = 900 # 15 minutes until departure ON_TIME_THRESHOLD = 1800 # 30 minutes until departure From f3c550013c989945abdf7d6774c4d40b7ef1ca0c Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 16:11:12 -0800 Subject: [PATCH 23/26] Fix reset --- src/airline_agent/tools/booking.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 2b1d832..3561e2e 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -40,6 +40,9 @@ def __init__(self) -> None: def _reset(self) -> None: """Clear all reservations and reset the random number generator for test isolation.""" + # Regenerate flights to clear any mutations (gates, terminals, status updates) + flights = generate_flight_data() + self._flights = {flight.id: flight for flight in flights} self._reservations = {} self._rng = random.Random(RNG_SEED) # noqa: S311 From ddb7142f30c6efe3e1316b44fb82d2deb073d962 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 16:16:06 -0800 Subject: [PATCH 24/26] Raise ModelRetry when model/user supplies bad args --- src/airline_agent/tools/booking.py | 49 +++++++++++++++--------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index 3561e2e..f53c0a5 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -2,6 +2,7 @@ from datetime import date, datetime, timedelta from typing import Any +from pydantic_ai import ModelRetry from pydantic_ai.toolsets import FunctionToolset from airline_agent.constants import DEMO_DATETIME @@ -60,9 +61,9 @@ def search_flights(self, origin: str, destination: str, departure_date: str) -> """ try: dep = date.fromisoformat(departure_date) - except Exception as e: + except Exception: # noqa: BLE001 msg = f"Invalid departure_date: {departure_date}" - raise ValueError(msg) from e + raise ModelRetry(msg) from None return [ fl @@ -85,7 +86,7 @@ def get_fare_details(self, flight_id: str, fare_type: str = "basic") -> dict[str """ if flight_id not in self._flights: msg = f"Flight not found: {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight = self._flights[flight_id] @@ -94,7 +95,7 @@ def get_fare_details(self, flight_id: str, fare_type: str = "basic") -> dict[str if not fare: available_fares = [f.fare_type for f in flight.fares] msg = f"Fare '{fare_type}' not available for flight {flight_id}. Available fares: {available_fares}" - raise ValueError(msg) + raise ModelRetry(msg) return { "flight_id": flight_id, @@ -132,7 +133,7 @@ def book_flights( """ if not flight_ids: msg = "At least one flight ID must be provided" - raise ValueError(msg) + raise ModelRetry(msg) now = DEMO_DATETIME # Generate deterministic booking ID using seeded random @@ -144,7 +145,7 @@ def book_flights( for flight_id in flight_ids: if flight_id not in self._flights: msg = f"Flight not found: {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight = self._flights[flight_id] @@ -153,11 +154,11 @@ def book_flights( if not fare: available_fares = [f.fare_type for f in flight.fares] msg = f"Fare '{fare_type}' not available for flight {flight_id}. Available fares: {available_fares}" - raise ValueError(msg) + raise ModelRetry(msg) if fare.seats_available <= 0: msg = f"No seats available for fare '{fare_type}' for flight {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight_bookings.append( FlightBooking( @@ -199,7 +200,7 @@ def get_booking(self, booking_id: str) -> Booking: """ if booking_id not in self._reservations: msg = f"Booking not found: {booking_id}" - raise ValueError(msg) + raise ModelRetry(msg) return self._reservations[booking_id] def get_my_bookings(self) -> list[Booking]: @@ -235,7 +236,7 @@ def add_service_to_booking( """ if booking_id not in self._reservations: msg = f"Booking not found: {booking_id}" - raise ValueError(msg) + raise ModelRetry(msg) booking = self._reservations[booking_id] @@ -244,12 +245,12 @@ def add_service_to_booking( if not flight_booking: available_flights = [fb.flight_id for fb in booking.flights] msg = f"Flight {flight_id} not found in booking {booking_id}. Available flights: {available_flights}" - raise ValueError(msg) + raise ModelRetry(msg) # Get the flight to check available add-ons if flight_id not in self._flights: msg = f"Flight not found: {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight = self._flights[flight_id] @@ -260,7 +261,7 @@ def add_service_to_booking( msg = ( f"Service '{service_type}' not available for flight {flight_id}. Available add-ons: {available_addons}" ) - raise ValueError(msg) + raise ModelRetry(msg) # Check if service is already included in the fare # Special handling for checked_bag (tracked via count, not in included_services) @@ -270,19 +271,19 @@ def add_service_to_booking( f"Checked bag(s) are already included in the {flight_booking.fare_type} fare " f"for flight {flight_id} ({flight_booking.checked_bags_included} bag(s) included)" ) - raise ValueError(msg) + raise ModelRetry(msg) elif service_type in flight_booking.included_services: msg = ( f"Service '{service_type}' is already included in the {flight_booking.fare_type} fare " f"for flight {flight_id}" ) - raise ValueError(msg) + raise ModelRetry(msg) # Check if add-on already exists existing_addon = next((ao for ao in flight_booking.add_ons if ao.service_type == service_type), None) if existing_addon: msg = f"Service '{service_type}' has already been added to flight {flight_id} in this booking" - raise ValueError(msg) + raise ModelRetry(msg) # Add the service add-on now = DEMO_DATETIME @@ -313,7 +314,7 @@ def add_service_to_booking( # For non-seat services, validate that seat parameters weren't provided if seat_preference or seat_assignment: msg = "seat_preference and seat_assignment can only be set for seat selection service types" - raise ValueError(msg) + raise ModelRetry(msg) addon = GenericServiceAddOn( service_type=service_type, @@ -434,28 +435,28 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: """ if booking_id not in self._reservations: msg = f"Booking not found: {booking_id}" - raise ValueError(msg) + raise ModelRetry(msg) booking = self._reservations[booking_id] if booking.status.status != "confirmed": msg = f"Cannot check in for booking {booking_id}: booking status is {booking.status.status}" - raise ValueError(msg) + raise ModelRetry(msg) # Find the flight in the booking flight_booking = next((fb for fb in booking.flights if fb.flight_id == flight_id), None) if not flight_booking: available_flights = [fb.flight_id for fb in booking.flights] msg = f"Flight {flight_id} not found in booking {booking_id}. Available flights: {available_flights}" - raise ValueError(msg) + raise ModelRetry(msg) if flight_booking.checked_in: msg = f"Already checked in for flight {flight_id} in booking {booking_id}" - raise ValueError(msg) + raise ModelRetry(msg) # Get the flight details if flight_id not in self._flights: msg = f"Flight not found: {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight = self._flights[flight_id] now = DEMO_DATETIME @@ -491,7 +492,7 @@ def get_flight_timings(self, flight_id: str) -> dict[str, Any]: """ if flight_id not in self._flights: msg = f"Flight not found: {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight = self._flights[flight_id] timings = self._calculate_check_in_timings(flight.departure) @@ -531,7 +532,7 @@ def get_flight_status(self, flight_id: str) -> dict[str, Any]: """ if flight_id not in self._flights: msg = f"Flight not found: {flight_id}" - raise ValueError(msg) + raise ModelRetry(msg) flight = self._flights[flight_id] From b67382f89d4fa537b869c6d235482fec572eeb9d Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 16:37:44 -0800 Subject: [PATCH 25/26] Remove some unnecessary docstrings and comments --- .../data_generation/generate_flights.py | 12 -------- src/airline_agent/tools/booking.py | 9 ------ tests/test_booking.py | 29 ------------------- 3 files changed, 50 deletions(-) diff --git a/src/airline_agent/data_generation/generate_flights.py b/src/airline_agent/data_generation/generate_flights.py index 2453484..06835e7 100644 --- a/src/airline_agent/data_generation/generate_flights.py +++ b/src/airline_agent/data_generation/generate_flights.py @@ -68,17 +68,14 @@ def get_flight_duration(origin: str, destination: str) -> float: - """Get flight duration in hours for a route.""" return FLIGHT_DURATIONS[(origin, destination)] def get_airport_timezone(airport: str) -> ZoneInfo: - """Get timezone for an airport.""" return AIRPORT_TIMEZONES[airport] def generate_fares(rng: random.Random) -> list[Fare]: - """Generate random fares for a flight with different fare bundles (Frontier Airlines model).""" fares = [] # Basic fare: no services included @@ -152,7 +149,6 @@ def generate_fares(rng: random.Random) -> list[Fare]: def generate_add_ons(rng: random.Random) -> list[ServiceAddOnOption]: - """Generate available add-on services for a flight.""" return [ ServiceAddOnOption( service_type="checked_bag", @@ -212,7 +208,6 @@ def generate_add_ons(rng: random.Random) -> list[ServiceAddOnOption]: def generate_flight_id(origin: str, destination: str, departure: datetime, carrier: str) -> str: - """Generate a unique flight ID.""" date_str = departure.strftime("%Y-%m-%dT%H:%M") return f"{carrier}-{origin}-{destination}-{date_str}" @@ -224,7 +219,6 @@ def generate_direct_flights( origin_airports: list[str] | None = None, dest_airports: list[str] | None = None, ) -> list[Flight]: - """Generate direct flights from origin airports to destination airports.""" if origin_airports is None: origin_airports = SF_AIRPORTS if dest_airports is None: @@ -276,12 +270,6 @@ def generate_direct_flights( def generate_flight_data() -> list[Flight]: - """ - Generate direct flight data for SF Bay Area to New York routes. - - Returns: - List of Flight objects - """ # Create seeded random number generator for reproducibility rng = random.Random(RNG_SEED) # noqa: S311 diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index f53c0a5..e443233 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -29,14 +29,11 @@ class BookingTools: def __init__(self) -> None: - # Generate flight data dynamically at initialization flights = generate_flight_data() self._flights: dict[str, Flight] = {flight.id: flight for flight in flights} - # Initialize reservations storage (in-memory only) self._reservations: dict[str, Booking] = {} - # Initialize seeded random number generator for deterministic behavior self._rng = random.Random(RNG_SEED) # noqa: S311 def _reset(self) -> None: @@ -285,7 +282,6 @@ def add_service_to_booking( msg = f"Service '{service_type}' has already been added to flight {flight_id} in this booking" raise ModelRetry(msg) - # Add the service add-on now = DEMO_DATETIME # Create appropriate add-on type based on service type @@ -325,10 +321,8 @@ def add_service_to_booking( flight_booking.add_ons.append(addon) - # Update booking timestamp booking.status.updated_at = now - # Save the updated booking in memory self._reservations[booking_id] = booking return booking @@ -472,10 +466,8 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: flight_booking.checked_in = True flight_booking.checked_in_at = now - # Update booking timestamp booking.status.updated_at = now - # Save changes self._reservations[booking_id] = booking return booking @@ -574,7 +566,6 @@ def get_flight_status(self, flight_id: str) -> dict[str, Any]: @property def tools(self) -> FunctionToolset: - """Returns a FunctionToolset containing all booking tools.""" return FunctionToolset( tools=[ self.search_flights, diff --git a/tests/test_booking.py b/tests/test_booking.py index 29113cf..c915da0 100644 --- a/tests/test_booking.py +++ b/tests/test_booking.py @@ -3,7 +3,6 @@ def test_search_flights() -> None: - """Test searching for flights between SF and NYC.""" agent = Agent(cleanlab_enabled=False) answer, _ = agent.chat("I want to fly from SFO to JFK on November 12, 2025") assert_judge( @@ -16,11 +15,8 @@ def test_search_flights() -> None: def test_get_fare_details() -> None: - """Test getting fare bundle details for a flight.""" agent = Agent(cleanlab_enabled=False) - # First search for flights agent.chat("Show me flights from SFO to EWR on November 12, 2025") - # Then ask about fare details answer, _ = agent.chat("What's included in the economy fare bundle for the first flight?") assert_judge( [ @@ -32,9 +28,7 @@ def test_get_fare_details() -> None: def test_book_single_flight() -> None: - """Test booking a single flight.""" agent = Agent(cleanlab_enabled=False) - # Search and book agent.chat("I need a flight from SFO to JFK on November 12, 2025") answer, _ = agent.chat("Book the first available flight with basic fare") assert_judge( @@ -48,13 +42,9 @@ def test_book_single_flight() -> None: def test_book_round_trip() -> None: - """Test booking a round trip (outbound and return flights).""" agent = Agent(cleanlab_enabled=False) - # Search for outbound agent.chat("Find flights from OAK to LGA on November 13, 2025") - # Search for return agent.chat("Find return flights from LGA to OAK on November 15, 2025") - # Book both answer, _ = agent.chat("Book the first flight for each leg with economy fare") assert_judge( [ @@ -67,12 +57,9 @@ def test_book_round_trip() -> None: def test_retrieve_booking() -> None: - """Test retrieving booking details by booking ID.""" agent = Agent(cleanlab_enabled=False) - # Create a booking first agent.chat("Find a flight from SJC to JFK on November 12, 2025") booking_response, _ = agent.chat("Book the first flight with basic fare") - # Retrieve bookings answer, _ = agent.chat("Show me my bookings") assert_judge( [ @@ -84,12 +71,9 @@ def test_retrieve_booking() -> None: def test_add_service_to_booking() -> None: - """Test adding a service (checked bag) to an existing booking.""" agent = Agent(cleanlab_enabled=False) - # Create a booking agent.chat("Show me flights from SFO to EWR on November 14, 2025") agent.chat("Book the first flight with basic fare") - # Add a service answer, _ = agent.chat("Add a checked bag to my booking") assert_judge( [ @@ -101,12 +85,9 @@ def test_add_service_to_booking() -> None: def test_check_in() -> None: - """Test checking in for a flight.""" agent = Agent(cleanlab_enabled=False) - # Create a booking agent.chat("Find flights from SFO to JFK on November 12, 2025") agent.chat("Book the first available flight") - # Check in answer, _ = agent.chat("Check me in for my flight") assert_judge( [ @@ -118,11 +99,8 @@ def test_check_in() -> None: def test_flight_status() -> None: - """Test getting flight status information.""" agent = Agent(cleanlab_enabled=False) - # Search for a flight first to get context agent.chat("Show me flights from OAK to LGA on November 12, 2025") - # Ask for status answer, _ = agent.chat("What's the status of the first flight?") assert_judge( [ @@ -134,11 +112,8 @@ def test_flight_status() -> None: def test_flight_timings() -> None: - """Test getting flight timing windows (check-in, boarding, etc.).""" agent = Agent(cleanlab_enabled=False) - # Search for a flight agent.chat("Find flights from SJC to EWR on November 13, 2025") - # Ask about timings answer, _ = agent.chat("When does check-in open for the first flight?") assert_judge( [ @@ -150,7 +125,6 @@ def test_flight_timings() -> None: def test_fare_comparison() -> None: - """Test comparing different fare bundles.""" agent = Agent(cleanlab_enabled=False) agent.chat("Show me flights from SFO to JFK on November 12, 2025") answer, _ = agent.chat("What's the difference between basic and premium fare for the first flight?") @@ -164,7 +138,6 @@ def test_fare_comparison() -> None: def test_invalid_route() -> None: - """Test handling of invalid or unavailable routes.""" agent = Agent(cleanlab_enabled=False) answer, _ = agent.chat("Find flights from SFO to Tokyo on November 12, 2025") assert_judge( @@ -177,7 +150,6 @@ def test_invalid_route() -> None: def test_no_date_provided() -> None: - """Test that agent asks for date when searching flights without one.""" agent = Agent(cleanlab_enabled=False) answer, _ = agent.chat("I want to fly from SFO to JFK") assert_judge( @@ -190,7 +162,6 @@ def test_no_date_provided() -> None: def test_no_existing_bookings() -> None: - """Test isolation: verify no bookings exist from previous tests.""" agent = Agent(cleanlab_enabled=False) answer, _ = agent.chat("Show me my bookings") assert_judge( From 8a95a6743054e437ed92a3d70a200f896e9bafb0 Mon Sep 17 00:00:00 2001 From: Anish Athalye Date: Wed, 5 Nov 2025 16:40:22 -0800 Subject: [PATCH 26/26] Remove unnecessary assignments --- src/airline_agent/tools/booking.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/airline_agent/tools/booking.py b/src/airline_agent/tools/booking.py index e443233..f344981 100644 --- a/src/airline_agent/tools/booking.py +++ b/src/airline_agent/tools/booking.py @@ -323,8 +323,6 @@ def add_service_to_booking( booking.status.updated_at = now - self._reservations[booking_id] = booking - return booking def _assign_seat(self, flight_booking: FlightBooking, _flight_id: str) -> str: @@ -468,8 +466,6 @@ def check_in(self, booking_id: str, flight_id: str) -> Booking: booking.status.updated_at = now - self._reservations[booking_id] = booking - return booking def get_flight_timings(self, flight_id: str) -> dict[str, Any]: