From 2e17c74363e265d48258f662fba3c641b88a46e1 Mon Sep 17 00:00:00 2001
From: baskrahmer
Date: Tue, 30 Apr 2024 18:41:39 +0200
Subject: [PATCH] Deduplicate code

---
 scripts/train_script.py |  9 +++++++--
 src/data.py             |  9 +++++++++
 src/plot.py             | 40 ++++++----------------------------------
 src/preprocess.py       | 14 +++-----------
 src/utils.py            | 12 +++++++++---
 5 files changed, 34 insertions(+), 50 deletions(-)

diff --git a/scripts/train_script.py b/scripts/train_script.py
index ea065ba..8572ca2 100644
--- a/scripts/train_script.py
+++ b/scripts/train_script.py
@@ -2,11 +2,16 @@
 from src.main import main
 
 c = Config(
+    data_analysis=True,
     use_wandb=True,
     sample_size=10000,
     model_name="sentence-transformers/distiluse-base-multilingual-cased-v2",
-    train_steps=1000,
-    val_steps=100,
+    train_steps=10,
+    val_steps=5,
+    learning_rate=0.01,
+    train_batch_size=256,
+    objective="hybrid",
+    score_metric="test_all_mae",
 )
 
 main(c)
diff --git a/src/data.py b/src/data.py
index 15e2d8a..b377fcc 100644
--- a/src/data.py
+++ b/src/data.py
@@ -1,7 +1,9 @@
 import math
+import os
 from collections import Counter
 from typing import Optional
 
+import pandas as pd
 from datasets import DatasetDict
 from datasets import load_dataset, concatenate_datasets
 
@@ -33,3 +35,10 @@ def get_dataset(c: Config, data_path: str, test_size: float, cls_dataset: Option
         dataset["train"] = concatenate_datasets([dataset["train"], cls_train])
         dataset["test"] = concatenate_datasets([dataset["test"], cls_test])
     return dataset
+
+
+def get_ciqual_data(c: Config):
+    # TODO deduplicate this with other CIQUAL loading logic
+    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
+    ciqual_path = os.path.join(data_dir, c.ciqual_filename)
+    return pd.read_csv(ciqual_path)
diff --git a/src/plot.py b/src/plot.py
index 99c0b6f..bf060c0 100644
--- a/src/plot.py
+++ b/src/plot.py
@@ -8,26 +8,7 @@
 import seaborn as sns
 
 from src.config import Config
-
-
-def plot_ciqual_distribution(ciqual_path):
-    footprint_scores = pd.read_csv(ciqual_path)["Score unique EF"]
-    num_bins = 10
-    plt.hist(footprint_scores, bins=num_bins, edgecolor='black')
-
-    # Adding labels and title
-    plt.xlabel('Value')
-    plt.ylabel('Frequency')
-    plt.title('Co2e distribution')
-    bins = np.linspace(min(footprint_scores), max(footprint_scores), num_bins + 1)
-    ticks = [(bins[i] + bins[i + 1]) / 2 for i in range(num_bins)]
-    formatter = ticker.FormatStrFormatter('%.2f')
-    plt.gca().xaxis.set_major_formatter(formatter)
-    plt.xticks(ticks)
-    plt.yscale("log")
-
-    # Displaying the plot
-    plt.show()
+from src.data import get_ciqual_data
 
 
 def plot_lang_label_frequencies(lang_label_frequencies):
@@ -81,17 +62,8 @@ def plot_lang_label_frequencies(lang_label_frequencies):
     plt.show()
 
 
-def make_data_analysis_report(c: Config, ciqual_path: str, lang_frequencies: dict, label_frequencies: dict,
-                              lang_label_frequencies: dict, output_path: str):
-    plot_ciqual_distribution(ciqual_path)
-    plot_lang_label_frequencies(lang_label_frequencies)
-
-
-from src.config import Config
-
-
-def plot_ciqual_distribution(ciqual_path, save_path):
-    footprint_scores = pd.read_csv(ciqual_path)["Score unique EF"]
+def plot_ciqual_distribution(c: Config, save_path: str):
+    footprint_scores = get_ciqual_data(c)["Score unique EF"]
     num_bins = 10
     plt.hist(footprint_scores, bins=num_bins, edgecolor='black')
 
@@ -171,8 +143,8 @@ def save_dict_to_json(data, save_path, filename):
         json.dump(data, f)
 
 
-def make_data_analysis_report(c: Config, ciqual_path: str, lang_frequencies: dict, label_frequencies: dict,
-                              lang_label_frequencies: dict, output_path: str, mlm: bool):
+def make_data_analysis_report(c: Config, lang_frequencies: dict, label_frequencies: dict, lang_label_frequencies: dict,
+                              output_path: str, mlm: bool):
     # Ensure the save directory exists
     os.makedirs(output_path, exist_ok=True)
 
@@ -183,5 +155,5 @@ def make_data_analysis_report(c: Config, ciqual_path: str, lang_frequencies: dic
     save_dict_to_json(lang_label_frequencies, output_path, 'lang_label_frequencies.json')
 
     # Plotting and saving plots
-    plot_ciqual_distribution(ciqual_path, output_path)
+    plot_ciqual_distribution(c, output_path)
     plot_lang_label_frequencies(lang_label_frequencies, output_path)
diff --git a/src/preprocess.py b/src/preprocess.py
index ab7f88c..8a22ec0 100644
--- a/src/preprocess.py
+++ b/src/preprocess.py
@@ -2,11 +2,11 @@
 import os
 from typing import Tuple, Any
 
-import pandas as pd
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase
 
 from src.config import Config
+from src.data import get_ciqual_data
 from src.plot import make_data_analysis_report
 
 
@@ -27,7 +27,7 @@ def filter_data(c: Config, mlm: bool = False) -> str:
     if c.cache_data and os.path.exists(filtered_products_path):
         return filtered_products_path
 
-    ciqual_data = pd.read_csv(ciqual_path)
+    ciqual_data = get_ciqual_data(c)
     ciqual_to_agb = {str(c): str(a) for c, a in zip(ciqual_data["Code CIQUAL"], ciqual_data["Code AGB"])}
     agb_set = set(ciqual_data["Code AGB"])
     del ciqual_data
@@ -73,8 +73,7 @@
             break
 
     if c.data_analysis:
-        make_data_analysis_report(c, ciqual_path, lang_frequencies, label_frequencies, lang_label_frequencies,
-                                  output_path, mlm)
+        make_data_analysis_report(c, lang_frequencies, label_frequencies, lang_label_frequencies, output_path, mlm)
 
     return filtered_products_path
 
@@ -114,10 +113,3 @@ def prepare_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_k
     sample['regressands'] = class_to_co2e[sample['label']]
     sample['classes'] = class_to_idx[sample['label']]
     return sample
-
-
-def get_ciqual_data(c: Config):
-    # TODO deduplicate this with other CIQUAL loading logic
-    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
-    ciqual_path = os.path.join(data_dir, c.ciqual_filename)
-    return pd.read_csv(ciqual_path)
diff --git a/src/utils.py b/src/utils.py
index 16b35f7..1ed0488 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,14 +1,15 @@
 import os
+from typing import Any
+from typing import List
+
 import torch
 from lightning import pytorch as pl
 from lightning.pytorch.loggers import CSVLogger, WandbLogger
 from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding
 from transformers import PreTrainedTokenizerBase
-from typing import Any
-from typing import List
 
 from src.config import Config
-from src.preprocess import get_ciqual_data
+from src.data import get_ciqual_data
 
 
 def get_loggers(c: Config):
@@ -79,3 +80,8 @@ def get_ciqual_mapping(c: Config):
     ciqual_data = get_ciqual_data(c)
     class_to_co2e = {str(c): co2 for c, co2 in zip(ciqual_data["Code AGB"], ciqual_data["Score unique EF"])}  # TODO
     return class_to_co2e
+
+
+def get_lci_name_mapping(c: Config):
+    ciqual_data = get_ciqual_data(c)
+    return {str(c): co2 for c, co2 in zip(ciqual_data["Code AGB"], ciqual_data["LCI Name"])}
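
Below is a minimal usage sketch of the call pattern this patch converges on: every CIQUAL read goes through src.data.get_ciqual_data(c), and callers such as get_ciqual_mapping build their lookups from that single loader instead of calling pd.read_csv themselves. The standalone Config stub, the assumed file layout, and the __main__ block are illustrative assumptions for running the sketch outside the repository; the helper names and the "Code AGB" / "Score unique EF" column names come from the hunks above.

# sketch.py -- illustrative only; mirrors the post-patch call pattern, not repository code.
import os

import pandas as pd


class Config:
    # Stand-in for src.config.Config with just the field the loader needs.
    ciqual_filename = "ciqual.csv"  # assumed filename, not taken from the patch


def get_ciqual_data(c: Config) -> pd.DataFrame:
    # Single CIQUAL loading path, as introduced in src/data.py by this patch.
    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
    return pd.read_csv(os.path.join(data_dir, c.ciqual_filename))


def get_ciqual_mapping(c: Config) -> dict:
    # Consumers derive their lookups from the shared frame instead of re-reading the CSV.
    ciqual_data = get_ciqual_data(c)
    return {str(code): co2 for code, co2 in zip(ciqual_data["Code AGB"], ciqual_data["Score unique EF"])}


if __name__ == "__main__":
    # Print a few entries of the code -> footprint mapping (assumes the CSV exists at the sketched path).
    print(list(get_ciqual_mapping(Config()).items())[:3])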