diff --git a/challenge/api.py b/challenge/api.py index 22d16f2..e18305a 100644 --- a/challenge/api.py +++ b/challenge/api.py @@ -1,19 +1,13 @@ -import sys +from datetime import datetime, timezone import fastapi import pandas as pd from fastapi import HTTPException -from pydantic import BaseModel +from pydantic import BaseModel, validator from challenge.model import DelayModel - -def print_to_file(whatever: any): - with open("file.txt", "a") as sys.stdout: - print(whatever) - - -valid_opera_values = [ +VALID_OPERA_VALUES = [ "american airlines", "air canada", "air france", @@ -39,24 +33,17 @@ def print_to_file(whatever: any): "lacsa", ] -valid_tipo_vuelo_values = [ +VALID_TIPO_VUELO_VALUES = [ "I", "N", ] -valid_mes_values = range(1, 13) - - -def valid_tipo_vuelo(tipo_vuelo: str) -> bool: - return tipo_vuelo in valid_tipo_vuelo_values +VALID_MES_VALUES = range(1, 13) -def valid_opera(opera: str) -> bool: - return opera in valid_opera_values - - -def valid_mes(mes_value: int) -> bool: - return mes_value in valid_mes_values +app = fastapi.FastAPI() +model = DelayModel() +model.load_model("models") class Flight(BaseModel): @@ -64,43 +51,50 @@ class Flight(BaseModel): TIPOVUELO: str MES: int - -class FlightData(BaseModel): - flights: list[Flight] - - -app = fastapi.FastAPI() -model = DelayModel() -model.load_model("models") - - -def flight_data_to_pandas(flight_data: FlightData) -> pd.DataFrame: - flight_data_dict = {"OPERA": [], "TIPOVUELO": [], "MES": []} - for elem in flight_data.flights: - if not valid_opera(elem.OPERA.lower()): + @validator("OPERA") + def valid_opera(cls, opera_value: str): + if opera_value.lower() not in VALID_OPERA_VALUES: raise HTTPException( status_code=400, detail=( - f"Value for tipo vuelo not valid. Recieved {elem.OPERA}," - f" expected one from {[v for v in valid_opera_values]}" + f"Value for tipo vuelo not valid. Recieved {opera_value}, " + f"expected one from {VALID_OPERA_VALUES}" ), ) - if not valid_tipo_vuelo(elem.TIPOVUELO.capitalize()): + return opera_value + + @validator("TIPOVUELO") + def valid_tipo_vuelo(cls, tipo_vuelo_value: str): + if tipo_vuelo_value.capitalize() not in VALID_TIPO_VUELO_VALUES: raise HTTPException( status_code=400, detail=( - f"Value for tipo vuelo not valid. Recieved {elem.TIPOVUELO}," - f" expected one from {[v for v in valid_tipo_vuelo_values]}" + f"Value for tipo vuelo not valid. Recieved {tipo_vuelo_value}, " + f"expected one from {VALID_TIPO_VUELO_VALUES}" ), ) - if not valid_mes(elem.MES): + return tipo_vuelo_value + + @validator("MES") + def valid_mes(cls, mes_value: int): + if mes_value not in VALID_MES_VALUES: raise HTTPException( status_code=400, detail=( - f"Value for tipo vuelo not valid. Recieved {elem.MES}," - f" expected one from {valid_mes_values}" + f"Value for tipo vuelo not valid. Recieved {mes_value}, " + f"expected one from {VALID_MES_VALUES}" ), ) + return mes_value + + +class FlightData(BaseModel): + flights: list[Flight] + + +def flight_data_to_pandas(flight_data: FlightData) -> pd.DataFrame: + flight_data_dict = {"OPERA": [], "TIPOVUELO": [], "MES": []} + for elem in flight_data.flights: flight_data_dict["OPERA"].append(elem.OPERA) flight_data_dict["TIPOVUELO"].append(elem.TIPOVUELO) flight_data_dict["MES"].append(elem.MES) @@ -108,6 +102,17 @@ def flight_data_to_pandas(flight_data: FlightData) -> pd.DataFrame: return pd.DataFrame(flight_data_dict) +@app.get("/", status_code=200) +async def root() -> dict: + return { + "message": ( + "welcome to the api for predicting flight delay. Use the /health " + "endpoint to get server status, and the /predict endpoint to get your " + "prediction from input data." + ) + } + + @app.get("/health", status_code=200) async def get_health() -> dict: return {"status": "OK"} @@ -115,14 +120,23 @@ async def get_health() -> dict: @app.post("/predict", status_code=200) async def post_predict(flight_data: FlightData) -> dict: - # get data and convert to pandas dataframe - - flight_data_df = flight_data_to_pandas(flight_data) - preprocessed_data = model.preprocess(flight_data_df) - - column_order = model._model.feature_names_in_ - preprocessed_data = preprocessed_data[column_order] - - pred = model.predict(preprocessed_data) - - return {"predict": pred} + try: + # get data and convert to pandas dataframe + flight_data_df = flight_data_to_pandas(flight_data) + preprocessed_data = model.preprocess(flight_data_df) + + # sorts column to feed the model + column_order = model._model.feature_names_in_ + preprocessed_data = preprocessed_data[column_order] + + pred = model.predict(preprocessed_data) + + return {"predict": pred} + except Exception as e: + # there may be exceptions we don't want to send to the clients, so log them in + # an internal file for debugging. Just as a cheap solution. + with open("error_logs.txt", "a") as f: + f.write(f"{datetime.now(timezone.utc)}: encounter error {e}") + raise HTTPException( + status_code=500, detail="Internal server error during prediction" + ) diff --git a/docs/challenge.md b/docs/challenge.md index 58be80c..dffe313 100644 --- a/docs/challenge.md +++ b/docs/challenge.md @@ -59,6 +59,28 @@ number of trees in xgboost). Also, it has the advantage that we can limit ourselves to only one framework (scikit learn), and have less imcompatibility issues when trying to move our model to production. +## Part II API developement + +Developed an api to serve the model's predictions properly. + +There is a welcome message at the root (`/`) entry-point, a health status check +at the `/health` entry-point, and the prediction service at the `/predict` +entry-point. + +This API, expects a directory named `models/` at the level of its execution, +where the model object will look for a `model.pkl` file, which stores a trained +instance of the selected model. + +Notice that the api mostly manages the reception of information, and does little +processing, i.e. convert the input list of flights into a pandas dataframe. + +Also, while on the prediction stage, where an error may occur, I've decided to +not report the error directly to the client, but to log it in an internal file, +and return a 500 error. This is not scalable, it's just an ad-hoc solution to +unwanted information leak to the client side of the api. Lot more information +could be into the log, and could be done with a proper library. But just to +showcase the proper railguard that needs to be there. + ## Part III - Deployment to Cloud A first step for deployin to cloud, is to build a Dockerfile for our application