Skip to content

Commit 6ebebe1

Browse files
committed
multilabel classification
1 parent f2fd545 commit 6ebebe1

File tree

7 files changed

+454
-3
lines changed

7 files changed

+454
-3
lines changed

Diff for: configs/multilabel.yaml

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Training configuration for the OSM multilabel classification experiment.
# Consumed by train.py via hydra; each _target_ is instantiated with
# hydra.utils.instantiate using the sibling keys as constructor kwargs.

dataloader:
  # Datamodule pairing precomputed chip embeddings with OSM-derived multilabels.
  _target_: earthtext.datamodules.chipmultilabel.ChipMultilabelModule
  # Parquet file with per-chip metadata and one-hot label columns.
  metadata_file: /opt/data/california-worldcover-chips/california-worldcover-chips-osm-multilabels.parquet
  # Folder holding the precomputed embedding files (one per chip).
  embeddings_folder: /opt/data/california-worldcover-chips/embeddings_v0.2
  # Also return the human-readable label strings alongside the one-hot vector.
  get_strlabels: True
  #num_workers: 10
  # Keep only labels whose one-hot count is at least this value.
  min_ohe_count: 1

model:
  # MLP head mapping a 768-d embedding to 99 independent label probabilities.
  _target_: earthtext.models.multilabel.MultilabelModel
  input_dim: 768          # must match the embedding dimension produced above
  output_dim: 99          # number of labels; presumably matches the parquet's label columns — TODO confirm
  layers_spec: [512, 256, 128]   # hidden layer widths, in order
  activation_fn: 'elu'    # hidden activation; one of relu / elu / sigmoid

Diff for: notebooks/models/01 - clay model v0.2 test.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@
257257
"name": "python",
258258
"nbconvert_exporter": "python",
259259
"pygments_lexer": "ipython3",
260-
"version": "3.10.9"
260+
"version": "3.12.1"
261261
}
262262
},
263263
"nbformat": 4,

Diff for: notebooks/models/04 - multilabel classification.ipynb

+350
Large diffs are not rendered by default.

Diff for: src/earthtext/datamodules/components/chipmultilabel.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def __init__(
3838
get_strlabels = False,
3939
get_esawc_proportions = False,
4040
get_chip_id = False,
41-
min_ohe_count = 1
41+
min_ohe_count = 1,
42+
cache_size = -1
4243
):
4344

4445
"""
@@ -67,6 +68,9 @@ def __init__(
6768
self.metadata = self.metadata[chips_exists]
6869
logger.info(f"read {split} split with {len(self.metadata)} chip files (out of {nitems})")
6970

71+
logger.info(f"max cache size is {cache_size}")
72+
self.cache_size = cache_size
73+
self.cache = {}
7074

7175
def prepare_data(self):
7276
"""This is an optional preprocessing step to be defined in each dataloader.
@@ -80,11 +84,14 @@ def __len__(self):
8084

8185

8286
def __repr__(self):
83-
return f"{self.__class__.__name__} {self.split} split with {len(self)} items"
87+
return f"{self.__class__.__name__} {self.split} split with {len(self)} items, in cache {len(self.cache)} items"
8488

8589

8690
def __getitem__(self, idx):
8791

92+
if idx in self.cache.keys():
93+
return self.cache[idx]
94+
8895
r = {}
8996

9097
item = self.metadata.iloc[idx]
@@ -115,6 +122,9 @@ def __getitem__(self, idx):
115122
if self.get_esawc_proportions:
116123
r['esawc_proportions'] = str(item.esawc_proportions)
117124

125+
# store in cache
126+
if self.cache_size == -1 or len(self.cache) < self.cache_size:
127+
self.cache[idx] = r
118128

119129
return r
120130

Diff for: src/earthtext/models/multilabel.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
import numpy as np
5+
6+
def activation_from_str(activation_str):
    """Return a fresh torch activation module for the given name.

    Supported names: ``'relu'``, ``'elu'``, ``'sigmoid'``.

    Args:
        activation_str: lowercase name of the activation function.

    Raises:
        ValueError: if ``activation_str`` is not a supported name.
    """
    factories = {
        'relu': nn.ReLU,
        'elu': nn.ELU,
        'sigmoid': nn.Sigmoid,
    }
    try:
        return factories[activation_str]()
    except KeyError:
        raise ValueError(f"unknown activation function string '{activation_str}'") from None
17+
18+
class MultilabelModel(nn.Module):
    """Multilayer perceptron head for multilabel classification.

    Expects a flat feature vector input of shape ``[batch_size, input_dim]``
    (the first layer is ``nn.Linear(input_dim, ...)``; a previous comment
    mentioning shape ``[batch_size, h, w, 2, 2]`` did not match the code).

    The network is ``input_dim -> layers_spec[0] -> ... -> layers_spec[-1]
    -> output_dim`` with the chosen activation after each hidden layer and a
    sigmoid on the output, so each of the ``output_dim`` outputs is an
    independent per-label probability in ``[0, 1]``.
    """

    # NOTE: the default was the mutable literal ``[10]``; a tuple avoids the
    # shared-mutable-default pitfall while keeping the same default network.
    def __init__(self, input_dim, output_dim, layers_spec=(10,), activation_fn='relu'):
        """
        Args:
            input_dim: size of the input feature vector.
            output_dim: number of labels (independent sigmoid outputs).
            layers_spec: hidden layer widths, in order. May be empty, in which
                case the model reduces to a single input->output layer.
            activation_fn: hidden activation name ('relu', 'elu' or 'sigmoid');
                see ``activation_from_str``.
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        # defensive copy so a caller-held list cannot mutate the model's record
        self.layers_spec = list(layers_spec)
        self.activation_fn = activation_fn

        layers = []
        prev_dim = input_dim
        # hidden stack; loop also handles the empty-spec case gracefully
        # (the original indexed layers_spec[0] and crashed on an empty spec)
        for hidden_dim in self.layers_spec:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(activation_from_str(activation_fn))
            prev_dim = hidden_dim

        # output head: per-label probabilities via sigmoid
        layers.append(nn.Linear(prev_dim, output_dim))
        layers.append(activation_from_str('sigmoid'))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP; returns probabilities of shape ``[batch, output_dim]``."""
        return self.layers(x)
48+

Diff for: train.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
import random
3+
4+
import dotenv
5+
import hydra
6+
import lightning.pytorch as pl
7+
import numpy as np
8+
import omegaconf
9+
from loguru import logger
10+
from omegaconf import DictConfig
11+
12+
13+
14+
@hydra.main(
    version_base="1.1",
    config_path="configs",
    config_name="train.yaml",
)
def main(config: DictConfig):
    """Training entry point: instantiates the dataloader declared in config.

    Args:
        config: hydra-composed configuration; ``config.dataloader`` must hold
            an instantiable ``_target_`` spec (see ``configs/*.yaml``).
    """
    # log the fully-resolved config so every run records exactly what it used
    # (replaces a bare debug print of the DictConfig)
    logger.info(f"config:\n{omegaconf.OmegaConf.to_yaml(config)}")

    dataloader = hydra.utils.instantiate(config.dataloader)

    # replaces the leftover debug `print("XX", dataloader)`
    logger.info(f"instantiated dataloader: {dataloader}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)