diff --git a/flash/data/data_utils.py b/flash/data/data_utils.py index 0529884b34..50e6b3bf63 100644 --- a/flash/data/data_utils.py +++ b/flash/data/data_utils.py @@ -3,7 +3,7 @@ import pandas as pd -def labels_from_categorical_csv(csv: str, index_col: str, return_dict: dict = True) -> Union[Dict, List]: +def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = True) -> Union[Dict, List]: """ Returns a dictionary with {index_col: label} for each entry in the csv.