diff --git a/src/autolabel/dataset_loader.py b/src/autolabel/dataset_loader.py
index c884dc7a..824cbde9 100644
--- a/src/autolabel/dataset_loader.py
+++ b/src/autolabel/dataset_loader.py
@@ -14,155 +14,171 @@ class DatasetLoader:
     # TODO: add support for reading from SQL databases
     # TODO: add support for reading and loading datasets in chunks

-    @staticmethod
-    def read_csv(
+    def __init__(
+        self,
+        dataset: Union[str, pd.DataFrame],
+        config: AutolabelConfig,
+        max_items: int = 0,
+        start_index: int = 0,
+    ) -> None:
+        """DatasetLoader class to read and load datasets.
+
+        Args:
+            dataset (Union[str, pd.DataFrame]): path to the dataset or the dataframe
+            config (AutolabelConfig): config object
+            max_items (int, optional): max number of items to read. Defaults to 0.
+            start_index (int, optional): start index to read from. Defaults to 0.
+        """
+        self.dataset = dataset
+        self.config = config
+        self.max_items = max_items
+        self.start_index = start_index
+
+        if isinstance(dataset, str):
+            self._read_file(dataset, config, max_items, start_index)
+        elif isinstance(dataset, pd.DataFrame):
+            self._read_dataframe(dataset, config, max_items, start_index)
+
+    def _read_csv(
+        self,
         csv_file: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the csv file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the csv file and set self.dat, self.inputs and self.gt_labels

         Args:
             csv_file (str): path to the csv file
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         logger.debug(f"reading the csv from: {start_index}")
         delimiter = config.delimiter()
         label_column = config.label_column()

-        dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]

-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)

-    @staticmethod
-    def read_dataframe(
+    def _read_dataframe(
+        self,
         df: pd.DataFrame,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the csv file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the dataframe and set self.dat, self.inputs and self.gt_labels

         Args:
             df (pd.DataFrame): dataframe to read
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         label_column = config.label_column()

-        dat = df[start_index:].astype(str)
+        self.dat = df[start_index:].astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]

-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)

-    @staticmethod
-    def read_jsonl(
+    def _read_jsonl(
+        self,
         jsonl_file: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the jsonl file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the jsonl file and set self.dat, self.inputs and self.gt_labels

         Args:
             jsonl_file (str): path to the jsonl file
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         logger.debug(f"reading the jsonl from: {start_index}")
         label_column = config.label_column()

-        dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]

-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)

-    @staticmethod
-    def read_sql(
+    def _read_sql(
         self,
         sql: Union[str, Selectable],
         connection: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the sql query and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the sql query and set self.dat, self.inputs and self.gt_labels

         Args:
             connection (str): connection string
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         logger.debug(f"reading the sql from: {start_index}")
         label_column = config.label_column()

-        dat = pd.read_sql(sql, connection)[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_sql(sql, connection)[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]

-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)

-    @staticmethod
-    def read_file(
+    def _read_file(
+        self,
         file: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the file and set self.dat, self.inputs and self.gt_labels

         Args:
             file (str): path to the file
@@ -172,16 +188,13 @@

         Raises:
             ValueError: if the file format is not supported
-
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         if file.endswith(".csv"):
-            return DatasetLoader.read_csv(
+            return self._read_csv(
                 file, config, max_items=max_items, start_index=start_index
             )
        elif file.endswith(".jsonl"):
-            return DatasetLoader.read_jsonl(
+            return self._read_jsonl(
                 file, config, max_items=max_items, start_index=start_index
             )
         else:
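At a glance, the call-site migration this class implies, mirroring the updated tests further below. This is a sketch only, not part of the patch: `config.json` and `dataset.csv` are placeholder paths, and the top-level `LabelingAgent` import is assumed from the package's public API.

```python
from autolabel import LabelingAgent  # assumed public import, as used in the tests
from autolabel.dataset_loader import DatasetLoader

agent = LabelingAgent(config="config.json")  # placeholder config path

# Old API: static methods returning a (df, inputs, gt_labels) tuple
# df, inputs, gt_labels = DatasetLoader.read_csv("dataset.csv", agent.config)

# New API: one constructor dispatches on str vs. pd.DataFrame and
# exposes the parsed data as attributes
loader = DatasetLoader("dataset.csv", agent.config, max_items=100)
loader.dat        # pd.DataFrame of the sliced rows, values cast to str
loader.inputs     # list of per-row dicts from dat.to_dict(orient="records")
loader.gt_labels  # list of labels, or None when no label column is configured
```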
diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py
index 63799549..8cbd2062 100644
--- a/src/autolabel/labeler.py
+++ b/src/autolabel/labeler.py
@@ -75,14 +75,8 @@ def run(
         csv_file_name = (
             output_name if output_name else f"{dataset.replace('.csv','')}_labeled.csv"
         )
-        if isinstance(dataset, str):
-            df, inputs, gt_labels = DatasetLoader.read_file(
-                dataset, self.config, max_items, start_index
-            )
-        elif isinstance(dataset, pd.DataFrame):
-            df, inputs, gt_labels = DatasetLoader.read_dataframe(
-                dataset, self.config, max_items, start_index
-            )
+
+        dataset_loader = DatasetLoader(dataset, self.config, max_items, start_index)

         # Initialize task run and check if it already exists
         self.task_run = self.db.get_task_run(self.task_object.id, self.dataset.id)
@@ -90,7 +84,7 @@
         if self.task_run:
             logger.info("Task run already exists.")
             self.task_run = self.handle_existing_task_run(
-                self.task_run, csv_file_name, gt_labels=gt_labels
+                self.task_run, csv_file_name, gt_labels=dataset_loader.gt_labels
             )
         else:
             self.task_run = self.db.create_task_run(
@@ -102,7 +96,8 @@

         # If this dataset config is a string, read the corrresponding csv file
         if isinstance(seed_examples, str):
-            _, seed_examples, _ = DatasetLoader.read_csv(seed_examples, self.config)
+            seed_loader = DatasetLoader(seed_examples, self.config)
+            seed_examples = seed_loader.inputs

         # Check explanations are present in data if explanation_column is passed in
         if (
@@ -115,7 +110,7 @@
             )

         self.example_selector = ExampleSelectorFactory.initialize_selector(
-            self.config, seed_examples, df.keys().tolist()
+            self.config, seed_examples, dataset_loader.dat.keys().tolist()
         )

         num_failures = 0
@@ -123,15 +118,17 @@
         cost = 0.0
         postfix_dict = {}

-        indices = range(current_index, len(inputs), self.CHUNK_SIZE)
+        indices = range(current_index, len(dataset_loader.inputs), self.CHUNK_SIZE)
         for current_index in track_with_stats(
             indices,
             postfix_dict,
-            total=len(inputs) - current_index,
+            total=len(dataset_loader.inputs) - current_index,
             advance=self.CHUNK_SIZE,
             console=console,
         ):
-            chunk = inputs[current_index : current_index + self.CHUNK_SIZE]
+            chunk = dataset_loader.inputs[
+                current_index : current_index + self.CHUNK_SIZE
+            ]
             final_prompts = []
             for i, input_i in enumerate(chunk):
                 # Fetch few-shot seed examples
@@ -206,9 +203,9 @@
                     self.db.session, self.task_run.id
                 )
                 llm_labels = [LLMAnnotation(**a.llm_annotation) for a in db_result]
-                if gt_labels:
+                if dataset_loader.gt_labels:
                     eval_result = self.task.eval(
-                        llm_labels, gt_labels[: len(llm_labels)]
+                        llm_labels, dataset_loader.gt_labels[: len(llm_labels)]
                     )

                     for m in eval_result:
@@ -232,8 +229,10 @@
         llm_labels = [LLMAnnotation(**a.llm_annotation) for a in db_result]
         eval_result = None
         # if true labels are provided, evaluate accuracy of predictions
-        if gt_labels:
-            eval_result = self.task.eval(llm_labels, gt_labels[: len(llm_labels)])
+        if dataset_loader.gt_labels:
+            eval_result = self.task.eval(
+                llm_labels, dataset_loader.gt_labels[: len(llm_labels)]
+            )
             table = {}
             # TODO: serialize and write to file
             for m in eval_result:
@@ -245,7 +244,7 @@
             print_table(table, console=console, default_style=METRIC_TABLE_STYLE)

         # Write output to CSV
-        output_df = df.copy()
+        output_df = dataset_loader.dat.copy()
         output_df[self.config.task_name() + "_llm_labeled_successfully"] = [
             l.successfully_labeled for l in llm_labels
         ]
@@ -298,14 +297,7 @@
             dataset: path to a CSV dataset
         """
-        if isinstance(dataset, str):
-            df, inputs, _ = DatasetLoader.read_file(
-                dataset, self.config, max_items, start_index
-            )
-        elif isinstance(dataset, pd.DataFrame):
-            df, inputs, _ = DatasetLoader.read_dataframe(
-                dataset, self.config, max_items, start_index
-            )
+        dataset_loader = DatasetLoader(dataset, self.config, max_items, start_index)

         prompt_list = []
         total_cost = 0
@@ -315,7 +307,8 @@

         # If this dataset config is a string, read the corrresponding csv file
         if isinstance(seed_examples, str):
-            _, seed_examples, _ = DatasetLoader.read_file(seed_examples, self.config)
+            seed_loader = DatasetLoader(seed_examples, self.config)
+            seed_examples = seed_loader.inputs

         # Check explanations are present in data if explanation_column is passed in
         if (
@@ -328,12 +321,12 @@
             )

         self.example_selector = ExampleSelectorFactory.initialize_selector(
-            self.config, seed_examples, df.keys().tolist()
+            self.config, seed_examples, dataset_loader.dat.keys().tolist()
         )

-        input_limit = min(len(inputs), 100)
+        input_limit = min(len(dataset_loader.inputs), 100)
         for input_i in track(
-            inputs[:input_limit],
+            dataset_loader.inputs[:input_limit],
             description="Generating Prompts...",
             console=console,
         ):
@@ -349,11 +342,11 @@
             curr_cost = self.llm.get_cost(prompt=final_prompt, label="")
             total_cost += curr_cost

-        total_cost = total_cost * (len(inputs) / input_limit)
+        total_cost = total_cost * (len(dataset_loader.inputs) / input_limit)
         table = {
             "Total Estimated Cost": f"${maybe_round(total_cost)}",
-            "Number of Examples": len(inputs),
-            "Average cost per example": f"${maybe_round(total_cost / len(inputs))}",
+            "Number of Examples": len(dataset_loader.inputs),
+            "Average cost per example": f"${maybe_round(total_cost / len(dataset_loader.inputs))}",
         }
         table = {"parameter": list(table.keys()), "value": list(table.values())}

         print_table(table, show_header=False, console=console, styles=COST_TABLE_STYLES)
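One behavior in `plan()` worth calling out, unchanged in substance by this diff (only renamed to use `dataset_loader.inputs`): prompts are generated for at most the first 100 inputs, and the sampled cost is scaled linearly to the whole dataset. A self-contained sketch with an assumed per-prompt cost:

```python
# plan()-style cost estimate: price a capped sample, extrapolate linearly.
num_inputs = 2000  # size of the full dataset
input_limit = min(num_inputs, 100)

total_cost = 0.0
for _ in range(input_limit):
    total_cost += 0.0004  # assumed cost per prompt in dollars, for illustration

total_cost = total_cost * (num_inputs / input_limit)
print(f"${total_cost:.2f}")  # $0.80 under this assumed pricing
```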
diff --git a/tests/unit/test_data_loading.py b/tests/unit/test_data_loading.py
index 93fe92fc..482797a6 100644
--- a/tests/unit/test_data_loading.py
+++ b/tests/unit/test_data_loading.py
@@ -9,62 +9,62 @@
 def test_read_csv():
     agent = LabelingAgent(config=config_path)
-    data = DatasetLoader.read_csv(csv_path, agent.config)
+    dataset_loader = DatasetLoader(csv_path, agent.config)

     # test return types
-    assert isinstance(data, tuple)
-    assert isinstance(data[0], DataFrame)
-    assert isinstance(data[1], list)
-    assert isinstance(data[2], list) or data[2] is None
+    assert isinstance(dataset_loader, DatasetLoader)
+    assert isinstance(dataset_loader.dat, DataFrame)
+    assert isinstance(dataset_loader.inputs, list)
+    assert (
+        isinstance(dataset_loader.gt_labels, list) or dataset_loader.gt_labels is None
+    )

     # test reading_csv with max_items = 5, start_index = 5
-    data_max_5_index_5 = DatasetLoader.read_csv(
+    dataset_loader_max_5_index_5 = DatasetLoader(
         csv_path, agent.config, max_items=5, start_index=5
     )
-    assert data_max_5_index_5[0].shape[0] == 5
-    assert data_max_5_index_5[0].iloc[0].equals(data[0].iloc[5])
-    assert len(data_max_5_index_5[1]) == 5
-    assert len(data_max_5_index_5[2]) == 5
+    assert dataset_loader_max_5_index_5.dat.shape[0] == 5
+    assert dataset_loader_max_5_index_5.dat.iloc[0].equals(dataset_loader.dat.iloc[5])
+    assert len(dataset_loader_max_5_index_5.inputs) == 5
+    assert len(dataset_loader_max_5_index_5.gt_labels) == 5


 def test_read_dataframe():
     agent = LabelingAgent(config=config_path)
-    df, _, _ = DatasetLoader.read_csv(csv_path, agent.config)
-    data = DatasetLoader.read_dataframe(df, agent.config)
+    df = DatasetLoader(csv_path, agent.config).dat
+    dataset_loader = DatasetLoader(df, agent.config)

     # test return types
-    assert isinstance(data, tuple)
-    assert isinstance(data[0], DataFrame)
-    assert isinstance(data[1], list)
-    assert isinstance(data[2], list) or data[2] is None
+    assert isinstance(dataset_loader, DatasetLoader)
+    assert isinstance(dataset_loader.dat, DataFrame)
+    assert isinstance(dataset_loader.inputs, list)
+    assert (
+        isinstance(dataset_loader.gt_labels, list) or dataset_loader.gt_labels is None
+    )

     # confirm data matches
-    assert df.equals(data[0])
+    assert df.equals(dataset_loader.dat)

     # test loading data with max_items = 5, start_index = 5
-    data_max_5_index_5 = DatasetLoader.read_dataframe(
+    dataset_loader_max_5_index_5 = DatasetLoader(
         df, agent.config, max_items=5, start_index=5
     )
-    assert data_max_5_index_5[0].shape[0] == 5
-    assert data_max_5_index_5[0].iloc[0].equals(data[0].iloc[5])
-    assert len(data_max_5_index_5[1]) == 5
-    assert len(data_max_5_index_5[2]) == 5
+    assert dataset_loader_max_5_index_5.dat.shape[0] == 5
+    assert dataset_loader_max_5_index_5.dat.iloc[0].equals(dataset_loader.dat.iloc[5])
+    assert len(dataset_loader_max_5_index_5.inputs) == 5
+    assert len(dataset_loader_max_5_index_5.gt_labels) == 5


 def test_read_jsonl():
     agent = LabelingAgent(config=config_path)
-    data = DatasetLoader.read_jsonl(jsonl_path, agent.config)
+    dataset_loader = DatasetLoader(jsonl_path, agent.config)

     # test return types
-    assert isinstance(data, tuple)
-    assert isinstance(data[0], DataFrame)
-    assert isinstance(data[1], list)
-    assert isinstance(data[2], list) or data[2] is None
+    assert isinstance(dataset_loader, DatasetLoader)
+    assert isinstance(dataset_loader.dat, DataFrame)
+    assert isinstance(dataset_loader.inputs, list)
+    assert (
+        isinstance(dataset_loader.gt_labels, list) or dataset_loader.gt_labels is None
+    )

     # test reading_csv with max_items = 5, start_index = 5
-    data_max_5_index_5 = DatasetLoader.read_jsonl(
+    dataset_loader_max_5_index_5 = DatasetLoader(
         jsonl_path, agent.config, max_items=5, start_index=5
     )
-    assert data_max_5_index_5[0].shape[0] == 5
-    assert data_max_5_index_5[0].iloc[0].equals(data[0].iloc[5])
-    assert len(data_max_5_index_5[1]) == 5
-    assert len(data_max_5_index_5[2]) == 5
-
-
-def test_read_file():
-    agent = LabelingAgent(config=config_path)
-    csv_data, _, _ = DatasetLoader.read_file(csv_path, agent.config)
-    jsonl_data, _, _ = DatasetLoader.read_file(jsonl_path, agent.config)
+    assert dataset_loader_max_5_index_5.dat.shape[0] == 5
+    assert dataset_loader_max_5_index_5.dat.iloc[0].equals(dataset_loader.dat.iloc[5])
+    assert len(dataset_loader_max_5_index_5.inputs) == 5
+    assert len(dataset_loader_max_5_index_5.gt_labels) == 5
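The `max_items`/`start_index` assertions above pin down the slicing semantics shared by every `_read_*` method: slice from `start_index` first, then cap at `max_items`. A standalone illustration in plain pandas (toy data, not the autolabel package):

```python
import pandas as pd

# Toy frame standing in for a parsed dataset with a label column.
df = pd.DataFrame({"example": [f"row {i}" for i in range(20)], "label": ["a"] * 20})

start_index, max_items = 5, 5

# Same order of operations as DatasetLoader._read_dataframe:
# slice from start_index, cast to str, then cap at max_items.
dat = df[start_index:].astype(str)
if max_items and max_items > 0:
    max_items = min(max_items, len(dat))
    dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = dat["label"].tolist()

assert dat.shape[0] == 5                # 5 rows kept
assert inputs[0]["example"] == "row 5"  # slicing starts at row 5
assert len(gt_labels) == 5
```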