Skip to content

Commit

Permalink
Deduplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
baskrahmer committed Apr 30, 2024
1 parent 59383ca commit 2e17c74
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 50 deletions.
9 changes: 7 additions & 2 deletions scripts/train_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
from src.main import main

c = Config(
data_analysis=True,
use_wandb=True,
sample_size=10000,
model_name="sentence-transformers/distiluse-base-multilingual-cased-v2",
train_steps=1000,
val_steps=100,
train_steps=10,
val_steps=5,
learning_rate=0.01,
train_batch_size=256,
objective="hybrid",
score_metric="test_all_mae",
)

main(c)
9 changes: 9 additions & 0 deletions src/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math
import os
from collections import Counter
from typing import Optional

import pandas as pd
from datasets import DatasetDict
from datasets import load_dataset, concatenate_datasets

Expand Down Expand Up @@ -33,3 +35,10 @@ def get_dataset(c: Config, data_path: str, test_size: float, cls_dataset: Option
dataset["train"] = concatenate_datasets([dataset["train"], cls_train])
dataset["test"] = concatenate_datasets([dataset["test"], cls_test])
return dataset


def get_ciqual_data(c: Config):
    """Read the CIQUAL CSV named by ``c.ciqual_filename`` into a DataFrame.

    The filename is resolved against the ``data`` directory located one
    level above the directory that contains this module.
    """
    # TODO deduplicate this with other CIQUAL loading logic
    module_dir = os.path.dirname(__file__)
    csv_path = os.path.join(module_dir, "..", "data", c.ciqual_filename)
    return pd.read_csv(csv_path)
40 changes: 6 additions & 34 deletions src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,7 @@
import seaborn as sns

from src.config import Config


def plot_ciqual_distribution(ciqual_path):
    """Show a log-scaled histogram of the "Score unique EF" column.

    The CSV at ``ciqual_path`` is read, its scores are binned into ten
    equal-width bins, and the x-axis ticks are placed at the bin centres.
    """
    num_bins = 10
    footprint_scores = pd.read_csv(ciqual_path)["Score unique EF"]
    plt.hist(footprint_scores, bins=num_bins, edgecolor='black')

    # Axis labels and title
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.title('Co2e distribution')

    # One tick per bin, positioned midway between consecutive bin edges
    edges = np.linspace(min(footprint_scores), max(footprint_scores), num_bins + 1)
    centres = [(left + right) / 2 for left, right in zip(edges[:-1], edges[1:])]
    plt.gca().xaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
    plt.xticks(centres)
    plt.yscale("log")

    # Render the figure
    plt.show()
from src.data import get_ciqual_data


def plot_lang_label_frequencies(lang_label_frequencies):
Expand Down Expand Up @@ -81,17 +62,8 @@ def plot_lang_label_frequencies(lang_label_frequencies):
plt.show()


def make_data_analysis_report(
        c: Config,
        ciqual_path: str,
        lang_frequencies: dict,
        label_frequencies: dict,
        lang_label_frequencies: dict,
        output_path: str,
):
    """Display the data-analysis plots.

    Plots the CIQUAL score distribution read from ``ciqual_path`` and then the
    per-language label frequencies from ``lang_label_frequencies``.

    ``c``, ``lang_frequencies``, ``label_frequencies`` and ``output_path`` are
    accepted but not used by this function.
    """
    plot_ciqual_distribution(ciqual_path)
    plot_lang_label_frequencies(lang_label_frequencies)


from src.config import Config


def plot_ciqual_distribution(ciqual_path, save_path):
footprint_scores = pd.read_csv(ciqual_path)["Score unique EF"]
def plot_ciqual_distribution(c: Config, save_path: str):
footprint_scores = get_ciqual_data(c)["Score unique EF"]
num_bins = 10
plt.hist(footprint_scores, bins=num_bins, edgecolor='black')

Expand Down Expand Up @@ -171,8 +143,8 @@ def save_dict_to_json(data, save_path, filename):
json.dump(data, f)


def make_data_analysis_report(c: Config, ciqual_path: str, lang_frequencies: dict, label_frequencies: dict,
lang_label_frequencies: dict, output_path: str, mlm: bool):
def make_data_analysis_report(c: Config, lang_frequencies: dict, label_frequencies: dict, lang_label_frequencies: dict,
output_path: str, mlm: bool):
# Ensure the save directory exists
os.makedirs(output_path, exist_ok=True)

Expand All @@ -183,5 +155,5 @@ def make_data_analysis_report(c: Config, ciqual_path: str, lang_frequencies: dic
save_dict_to_json(lang_label_frequencies, output_path, 'lang_label_frequencies.json')

# Plotting and saving plots
plot_ciqual_distribution(ciqual_path, output_path)
plot_ciqual_distribution(c, output_path)
plot_lang_label_frequencies(lang_label_frequencies, output_path)
14 changes: 3 additions & 11 deletions src/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import os
from typing import Tuple, Any

import pandas as pd
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from src.config import Config
from src.data import get_ciqual_data
from src.plot import make_data_analysis_report


Expand All @@ -27,7 +27,7 @@ def filter_data(c: Config, mlm: bool = False) -> str:
if c.cache_data and os.path.exists(filtered_products_path):
return filtered_products_path

ciqual_data = pd.read_csv(ciqual_path)
ciqual_data = get_ciqual_data(c)
ciqual_to_agb = {str(c): str(a) for c, a in zip(ciqual_data["Code CIQUAL"], ciqual_data["Code AGB"])}
agb_set = set(ciqual_data["Code AGB"])
del ciqual_data
Expand Down Expand Up @@ -73,8 +73,7 @@ def filter_data(c: Config, mlm: bool = False) -> str:
break

if c.data_analysis:
make_data_analysis_report(c, ciqual_path, lang_frequencies, label_frequencies, lang_label_frequencies,
output_path, mlm)
make_data_analysis_report(c, lang_frequencies, label_frequencies, lang_label_frequencies, output_path, mlm)

return filtered_products_path

Expand Down Expand Up @@ -114,10 +113,3 @@ def prepare_inputs(sample: dict, tokenizer: PreTrainedTokenizerBase, tokenizer_k
sample['regressands'] = class_to_co2e[sample['label']]
sample['classes'] = class_to_idx[sample['label']]
return sample


def get_ciqual_data(c: Config):
    """Read the CIQUAL CSV named by ``c.ciqual_filename`` into a DataFrame.

    The filename is resolved against the ``data`` directory located one
    level above the directory that contains this module.
    """
    # TODO deduplicate this with other CIQUAL loading logic
    module_dir = os.path.dirname(__file__)
    csv_path = os.path.join(module_dir, "..", "data", c.ciqual_filename)
    return pd.read_csv(csv_path)
12 changes: 9 additions & 3 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
from typing import Any
from typing import List

import torch
from lightning import pytorch as pl
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding
from transformers import PreTrainedTokenizerBase
from typing import Any
from typing import List

from src.config import Config
from src.preprocess import get_ciqual_data
from src.data import get_ciqual_data


def get_loggers(c: Config):
Expand Down Expand Up @@ -79,3 +80,8 @@ def get_ciqual_mapping(c: Config):
ciqual_data = get_ciqual_data(c)
class_to_co2e = {str(c): co2 for c, co2 in zip(ciqual_data["Code AGB"], ciqual_data["Score unique EF"])} # TODO
return class_to_co2e


def get_lci_name_mapping(c: Config):
    """Build a dict from each row's "Code AGB" (as ``str``) to its "LCI Name".

    When the same code appears in several rows, the value from the last such
    row is kept.
    """
    ciqual_data = get_ciqual_data(c)
    agb_codes = ciqual_data["Code AGB"]
    lci_names = ciqual_data["LCI Name"]
    return {str(code): name for code, name in zip(agb_codes, lci_names)}

0 comments on commit 2e17c74

Please sign in to comment.