-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathUSGS_classification.py
92 lines (72 loc) · 3.54 KB
/
USGS_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from deepforest import model
import pandas as pd
import glob
import comet_ml
from pytorch_lightning.loggers import CometLogger
from src.classification import preprocess_and_train
import hydra
from omegaconf import DictConfig
import os
import torch.nn.functional as F
# Per-class train/test split: each label holds out ~10% of its rows for test,
# bounded below by 10 and above by 100 images (and never more than it has).
def train_test_split(df, test_size=0.1, min_test_images=10, max_test_images=100, random_state=None):
    """Split ``df`` into train and test sets, stratified by the ``label`` column.

    For every label, ``int(n * test_size)`` rows are held out for test, clamped
    to ``[min_test_images, max_test_images]`` and capped at the number of rows
    that label actually has; the remaining rows go to train.

    Args:
        df: DataFrame with a ``label`` column.
        test_size: Fraction of each label's rows to hold out.
        min_test_images: Lower bound on held-out rows per label.
        max_test_images: Upper bound on held-out rows per label.
        random_state: Seed or RNG forwarded to ``DataFrame.sample`` for a
            reproducible split; ``None`` (default) keeps the split random.

    Returns:
        ``(train_df, test_df)``. Labels too small to satisfy
        ``min_test_images`` are held out entirely (they contribute no train
        rows). Rows keep their original index labels.
    """
    train_parts = []
    test_parts = []
    # dropna=False so rows with a missing label are split too instead of lost.
    for _, class_df in df.groupby("label", sort=False, dropna=False):
        test_count = max(min_test_images, int(len(class_df) * test_size))
        test_count = min(test_count, max_test_images, len(class_df))
        # Shuffle and slice by position instead of dropping by index label:
        # concatenated CSVs share index values, and ``drop`` would remove every
        # row carrying a sampled label, losing rows from both splits.
        shuffled = class_df.sample(frac=1, random_state=random_state)
        test_parts.append(shuffled.iloc[:test_count])
        train_parts.append(shuffled.iloc[test_count:])
    if not test_parts:
        # Empty input: pd.concat([]) would raise, so return empty frames.
        empty = df.iloc[0:0]
        return empty.copy(), empty.copy()
    # Concatenate once at the end rather than growing a frame per label.
    return pd.concat(train_parts), pd.concat(test_parts)
# Crop annotation CSVs written by the detection script.
CROP_ANNOTATION_GLOB = "/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops/*.csv"
# Trained classifier checkpoints are saved here as ``<comet experiment id>.ckpt``.
CHECKPOINT_DIR = "/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/classification/checkpoints/"
# A label needs strictly more than this many crops to be kept.
MIN_IMAGES_PER_LABEL = 100
# Pixels added to every side of each bounding box.
BOX_PADDING = 30


def _load_crop_annotations(pattern=CROP_ANNOTATION_GLOB, min_images=MIN_IMAGES_PER_LABEL, padding=BOX_PADDING):
    """Load, filter and pad the crop annotations produced by the detection script.

    Keeps labels with more than ``min_images`` rows whose name contains a space
    (two-word labels), then grows every box by ``padding`` pixels per side.
    ``xmin``/``ymin`` are clipped at 0 so padding never yields negative pixel
    coordinates; ``xmax``/``ymax`` are not clipped because image sizes are not
    known here.

    Raises:
        FileNotFoundError: if no CSV matches ``pattern``.
    """
    frames = [pd.read_csv(path) for path in glob.glob(pattern)]
    if not frames:
        raise FileNotFoundError(f"No crop annotation CSVs match {pattern!r}")
    # ignore_index: each CSV restarts its index at 0, and duplicate index labels
    # break index-based row selection downstream.
    annotations = pd.concat(frames, ignore_index=True)
    annotations = annotations.groupby("label").filter(lambda rows: len(rows) > min_images)
    # .copy() so the padding below writes to an owned frame, not a slice.
    annotations = annotations[annotations["label"].str.contains(" ", regex=False)].copy()
    annotations[["xmin", "ymin"]] = (annotations[["xmin", "ymin"]] - padding).clip(lower=0)
    annotations[["xmax", "ymax"]] = annotations[["xmax", "ymax"]] + padding
    return annotations


def _predict_validation(trained_model):
    """Run ``trained_model.model`` over ``trained_model.val_dataloader()``.

    Returns one row per validation sample with the arg-max predicted class
    index, its softmax probability and the batch's target value.
    """
    import torch

    net = trained_model.model
    net.eval()
    # Feed each batch to whatever device the weights currently live on.
    device = next(net.parameters()).device
    predicted_class, predicted_prob, true_class = [], [], []
    # no_grad: inference only, so don't build autograd graphs for every batch.
    with torch.no_grad():
        for x, y in trained_model.val_dataloader():
            probs = F.softmax(net(x.to(device)), dim=1)
            top_prob, top_class = probs.max(dim=1)
            predicted_class.extend(top_class.tolist())
            predicted_prob.extend(top_prob.tolist())
            true_class.extend(y.tolist() if torch.is_tensor(y) else list(y))
    return pd.DataFrame(
        {
            "predicted_class": predicted_class,
            "predicted_prob": predicted_prob,
            "true_class": true_class,
        }
    )


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Train the USGS crop classifier, checkpoint it and log validation results.

    The config Hydra passes in is replaced by a fresh composition with
    ``classification_model=USGS``, so command-line overrides given to this
    script do not reach training.
    """
    cfg = hydra.compose(config_name="config", overrides=["classification_model=USGS"])

    # NOTE(review): these annotations are loaded and cleaned but not used yet —
    # the split is disabled, so preprocess_and_train gets train_df=None and
    # validation_df=None. To use them:
    # ``train_df, validation_df = train_test_split(crop_annotations)``.
    crop_annotations = _load_crop_annotations()
    train_df = None
    validation_df = None

    comet_logger = CometLogger(project_name=cfg.comet.project, workspace=cfg.comet.workspace)
    trained_model = preprocess_and_train(
        train_df=train_df,
        validation_df=validation_df,
        comet_logger=comet_logger,
        **cfg.classification_model
    )

    comet_id = comet_logger.experiment.id
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{comet_id}.ckpt")
    trained_model.trainer.save_checkpoint(checkpoint_path)

    if trained_model.trainer.global_rank == 0:
        # Round-trip the checkpoint so a serialization problem fails here.
        model.CropModel.load_from_checkpoint(checkpoint_path, num_classes=trained_model.num_classes)
        predicted_frame = _predict_validation(trained_model)
        # log_asset_data expects text/bytes/JSON-able data, so send CSV text
        # to match the asset's .csv name.
        comet_logger.experiment.log_asset_data(
            predicted_frame.to_csv(index=False), name="predictions.csv"
        )

    # NOTE(review): runs on every rank; confirm create_trainer() builds a
    # trainer that is safe to validate from all ranks.
    trained_model.create_trainer()
    trained_model.trainer.validate(trained_model)


if __name__ == "__main__":
    main()